Commit 12a4c631 authored by Aubrey Li's avatar Aubrey Li
Browse files

Fix TypeError when invoke KLinearCPUInfer.forward()

Fix the following error:

  File "/home/aubrey/work/ktransformers/ktransformers/operators/linear.py", line 825, in forward
    y = self.generate_linear.forward(x, bsz_tensor)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: KLinearCPUInfer.forward() takes 2 positional arguments but 3 were given
parent 6ca743ed
......@@ -699,7 +699,7 @@ class KLinearCPUInfer(KLinearBase):
self.group_max_len = group_max_len
self.out_device = out_device
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
origin_shape = x.shape # [batch_size, q_len, hidden_size]
if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():
out_device = x.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment