Commit 0d350c8d authored by hyunwoongko's avatar hyunwoongko Committed by mshoeybi
Browse files

fix bugs in fused softmax

parent 116820a5
...@@ -186,7 +186,7 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -186,7 +186,7 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs return probs
@staticmethod @staticmethod
def get_batch_per_block(b, np, sq, sk): def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment