Unverified Commit 5e80b2a7 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Avoid using torch.compile for roll and fill_ (#609)


Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b4b8ae7b
...@@ -583,7 +583,7 @@ def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: ...@@ -583,7 +583,7 @@ def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
return amax_history return amax_history
@jit_fuser @torch.jit.script
def _default_get_amax( def _default_get_amax(
amax_history: torch.Tensor, amax_history: torch.Tensor,
amax_compute_algo: str, amax_compute_algo: str,
...@@ -625,7 +625,7 @@ def _compute_scaling_factor_inverse( ...@@ -625,7 +625,7 @@ def _compute_scaling_factor_inverse(
return torch.where(non_weight_mask, 1.0 / scale, scale_inv) return torch.where(non_weight_mask, 1.0 / scale, scale_inv)
@jit_fuser @torch.jit.script
def _fused_amax_and_scale_update( def _fused_amax_and_scale_update(
amax_history: torch.Tensor, amax_history: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment