Unverified Commit f8eb799a authored by Emmanuel Ferdman's avatar Emmanuel Ferdman Committed by GitHub
Browse files

[PyTorch] remove duplicate code (#1215)


Signed-off-by: default avatarEmmanuel Ferdman <emmanuelferdman@gmail.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9d976bcd
...@@ -1854,13 +1854,6 @@ def _run_ref_mha_f16(dtype, config, backend): ...@@ -1854,13 +1854,6 @@ def _run_ref_mha_f16(dtype, config, backend):
"""Get cuda rng tracker.""" """Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER return _DUMMY_CUDA_RNG_STATE_TRACKER
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = DotProductAttention( block = DotProductAttention(
config.num_heads, config.num_heads,
config.head_dim_qk, config.head_dim_qk,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment