Commit 63913f29 authored by LeeDongYeun, committed by Zhekai Zhang
Browse files

fix shape in GEMM W8A8

parent af6b1a3c
......@@ -451,9 +451,9 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
}
Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
auto oshape = qact.act.shape;
oshape[-1] = out_features;
Tensor out = Tensor::allocate(oshape, this->dtype, qact.act.device());
auto shape = TensorShape(qact.act.shape.dataExtent);
shape[-1] = out_features;
Tensor out = Tensor::allocate(shape, this->dtype, qact.act.device());
kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias);
debug("gemm.out", out);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment