Commit 0d350c8d authored by hyunwoongko's avatar hyunwoongko Committed by mshoeybi
Browse files

fix bugs in fused softmax

parent 116820a5
......@@ -186,7 +186,7 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs
@staticmethod
def get_batch_per_block(b, np, sq, sk):
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment