Unverified commit d668f18f, authored by Paweł Gadziński and committed by GitHub
Browse files

[Pytorch] Added missing assert_dim_for_fp8_exec for Linear



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* reshape inp
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
parent 7f2dcf91
...@@ -27,6 +27,7 @@ from ..utils import ( ...@@ -27,6 +27,7 @@ from ..utils import (
divide, divide,
init_method_constant, init_method_constant,
non_tn_fp8_gemm_supported, non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec,
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
requires_grad, requires_grad,
...@@ -118,13 +119,14 @@ class _Linear(torch.autograd.Function): ...@@ -118,13 +119,14 @@ class _Linear(torch.autograd.Function):
# Prepare input tensor # Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication # Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.input_cast_comm") nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp inputmat = inp.view(-1, in_features)
inputmat_total = None inputmat_total = None
with_input_all_gather_nccl = ( with_input_all_gather_nccl = (
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
) )
own_quantized_input = False own_quantized_input = False
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
if ( if (
any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) any([ub_overlap_ag_fprop, ub_overlap_rs_fprop])
and not FP8GlobalStateManager.get_fp8_recipe().delayed() and not FP8GlobalStateManager.get_fp8_recipe().delayed()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment