Commit 95a36eae, authored by Frank Lee, committed by GitHub

[kernel] added kernel loader to softmax autograd function (#3093)

* [kernel] added kernel loader to softmax autograd function

* [release] v0.2.6
parent fff98f06
@@ -180,4 +180,9 @@ class FusedScaleMaskSoftmax(nn.Module):
         return probs
     def get_batch_per_block(self, sq, sk, b, np):
+        # build and load kernel if not pre-built
+        global scaled_masked_softmax
+        if scaled_masked_softmax is None:
+            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
         return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
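
The change above follows a lazy-loading pattern: the kernel module is held in a module-level variable that starts as None and is built on the first call that needs it. Below is a minimal, self-contained sketch of that pattern. FakeSoftmaxBuilder and _FakeKernel are hypothetical stand-ins introduced only for illustration, since the real ScaledMaskedSoftmaxBuilder and its compiled kernel are not shown in this diff; the return value of the stand-in kernel is a placeholder, not the real computation.

```python
# Hypothetical stand-ins for illustration only; these are NOT the
# ColossalAI classes, whose internals are not shown in the diff.
class _FakeKernel:
    def get_batch_per_block(self, sq, sk, b, np):
        # Placeholder result; the real kernel computes this natively.
        return 1


class FakeSoftmaxBuilder:
    load_calls = 0  # counts how many times the "kernel" was built

    def load(self):
        FakeSoftmaxBuilder.load_calls += 1
        return _FakeKernel()


# Module-level handle, left empty until the kernel is first needed.
scaled_masked_softmax = None


def get_batch_per_block(sq, sk, b, np):
    # Build and load the kernel on first use, then reuse the cached module.
    global scaled_masked_softmax
    if scaled_masked_softmax is None:
        scaled_masked_softmax = FakeSoftmaxBuilder().load()
    return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)


get_batch_per_block(128, 128, 4, 8)
get_batch_per_block(128, 128, 4, 8)
print(FakeSoftmaxBuilder.load_calls)  # prints 1: the second call reuses the cache
```

The `global` statement is what lets the assignment inside the function rebind the module-level name; without it, Python would treat `scaled_masked_softmax` as a local variable and the `is None` check would raise UnboundLocalError.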