Unverified Commit 77c6cc82 authored by Azure's avatar Azure Committed by GitHub
Browse files

Merge pull request #1063 from aubreyli/KLinearCPUInfer.forward-fix

Fix TypeError when invoke KLinearCPUInfer.forward()
parents 6463070b 12a4c631
...@@ -699,7 +699,7 @@ class KLinearCPUInfer(KLinearBase): ...@@ -699,7 +699,7 @@ class KLinearCPUInfer(KLinearBase):
self.group_max_len = group_max_len self.group_max_len = group_max_len
self.out_device = out_device self.out_device = out_device
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
origin_shape = x.shape # [batch_size, q_len, hidden_size] origin_shape = x.shape # [batch_size, q_len, hidden_size]
if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing(): if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():
out_device = x.device out_device = x.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment