Unverified Commit 0d2021ef authored by Jeng Bai-Cheng, committed by GitHub
Browse files

[JAX] bugfix for softmax lowering (#218)



bugfix for softmax lowering
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
parent 8d4761ad
......@@ -382,14 +382,14 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.49 and actual[1] > 0.76
assert actual[0] < 0.50 and actual[1] > 0.76
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.49 and actual[1] > 0.76
assert actual[0] < 0.50 and actual[1] > 0.76
if __name__ == "__main__":
......
......@@ -39,6 +39,8 @@ def te_dtype_to_jax_dtype(te_dtype):
return jnp.bfloat16
if te_dtype == TEDType.kInt32:
return jnp.int32
if te_dtype == TEDType.kInt64:
return jnp.int64
return jnp.int8
......@@ -1677,7 +1679,7 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
grad_outputs, softmax_outputs,
scale_factor)
return [out]
return out # out is iterable already
_scaled_softmax_bwd_p = register_primitive(ScaledSoftmaxBwdPrimitive)
......@@ -1826,7 +1828,7 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
grad_outputs, softmax_outputs,
scale_factor)
return [out]
return out # out is iterable already
_scaled_masked_softmax_bwd_p = register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
......@@ -1960,7 +1962,7 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, ctx, grad_outputs, softmax_outputs,
scale_factor)
return [out]
return out # out is iterable already
_scaled_upper_triang_masked_softmax_bwd_p = \
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
......
......@@ -69,7 +69,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
.value("kInt32", DType::kInt32)
.value("KInt64", DType::kInt64)
.value("kInt64", DType::kInt64)
.value("kFloat32", DType::kFloat32)
.value("kFloat16", DType::kFloat16)
.value("kBFloat16", DType::kBFloat16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment