Unverified Commit 793a54bf authored by rybakov, committed by GitHub

remove code duplication in a test (#915)


Signed-off-by: Oleg Rybakov <orybakov@nvidia.com>
Co-authored-by: Oleg Rybakov <orybakov@nvidia.com>
parent 43678153
@@ -328,6 +328,57 @@ def get_attn_mask_str(use_mask, attn_mask_type):
     return attn_mask_str
 
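+# Shared wrapper around the raw FP8 cast + GEMM extension calls, hoisted to
+# module scope so the export tests below can reuse one definition instead of
+# each carrying its own copy of this class.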
+class FP8GemmModule(nn.Module):
+    def __init__(self, precision, use_bias, gelu, scale_factors, hidden_size, out_features):
+        super().__init__()
+        self.use_bias = use_bias
+        self.gelu = gelu
+        self.precision = precision
+        self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
+        self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
+        nb_inp_scales, nb_weight_scales = 1, out_features
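+        # The input uses a single scale; the weight gets one scale per output
+        # channel (see the create_meta calls below).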
+        act_scale_factor, weight_scale_factor = scale_factors
+        self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
+        self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
+        bias_size = nb_weight_scales
+        self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
+        self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
+        self.inp_type = tex.DType.kFloat8E4M3
+        self.weights_type = tex.DType.kFloat8E4M3
+        self.outp_type = precision
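+
+    # Quantize both operands to FP8, then run the GEMM through the raw
+    # extension API.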
+    def forward(self, inp, weight):
+        inp_fp8 = cast_to_fp8(
+            inp,
+            self.meta_inp,
+            self.fp8_tensor_inp,
+            self.inp_type)
+        weight_fp8 = cast_to_fp8(
+            weight,
+            self.meta_weight,
+            self.fp8_tensor_weight,
+            self.weights_type)
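+        # Both dtype arguments below are E4M3 (input and weight use the same
+        # FP8 format); the GEMM result comes back in the test precision.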
+        ret, _ = fp8_gemm(
+            weight_fp8,
+            self.meta_weight.scale_inv,
+            self.fp8_tensor_weight,
+            self.inp_type,
+            inp_fp8,
+            self.meta_inp.scale_inv,
+            self.fp8_tensor_inp,
+            self.weights_type,
+            self.outp_type,
+            get_workspace(),
+            bias=self.bias,
+            use_bias=self.use_bias,
+            use_split_accumulator=False)
+        return ret
""" """
Tests cases begin here. Tests cases begin here.
""" """
@@ -477,57 +528,6 @@ def test_export_gemm(
     if use_fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    class TestFP8_GEMM(nn.Module):
-        def __init__(self, precision, use_bias, gelu, scale_factors):
-            super().__init__()
-            self.use_bias = use_bias
-            self.gelu = gelu
-            self.precision = precision
-            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
-            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
-            nb_inp_scales, nb_weight_scales = 1, out_features
-            act_scale_factor, weight_scale_factor = scale_factors
-            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
-            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
-            bias_size = nb_weight_scales
-            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
-            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
-            self.inp_type = tex.DType.kFloat8E4M3
-            self.weights_type = tex.DType.kFloat8E4M3
-            self.outp_type = precision
-
-        def forward(self, inp, weight):
-            inp_fp8 = cast_to_fp8(
-                inp,
-                self.meta_inp,
-                self.fp8_tensor_inp,
-                self.inp_type)
-            weight_fp8 = cast_to_fp8(
-                weight,
-                self.meta_weight,
-                self.fp8_tensor_weight,
-                self.weights_type)
-            ret, _ = fp8_gemm(
-                weight_fp8,
-                self.meta_weight.scale_inv,
-                self.fp8_tensor_weight,
-                self.inp_type,
-                inp_fp8,
-                self.meta_inp.scale_inv,
-                self.fp8_tensor_inp,
-                self.weights_type,
-                self.outp_type,
-                get_workspace(),
-                bias=self.bias,
-                use_bias=self.use_bias,
-                use_split_accumulator=False)
-            return ret
-
     class Test_GEMM(nn.Module):
         def __init__(self, precision, use_bias=False, gelu=False):
             super().__init__()
@@ -576,7 +576,7 @@ def test_export_gemm(
     fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
     input_names = ['input', 'weight']
     if use_fp8:
-        model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
+        model = FP8GemmModule(precision, use_bias, use_gelu, scale_factors, hidden_size, out_features)
         do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
         te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
         serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
@@ -1253,6 +1253,9 @@ def test_export_gemm_layernorm(
     The test sets up GEMM with FP32 output which feeds into an LN that is configured
     with FP16 or BF16 weights and bias.
     """
+    out_features = 128
+    hidden_size = 128
+    in_features = 128
     # Skip FP8 tests on non-Hopper devices
     if use_fp8 and not fp8_available:
@@ -1268,8 +1271,9 @@ def test_export_gemm_layernorm(
             self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
             self.meta = create_meta(ln_scale_factor)
             self.fp8_type = tex.DType.kFloat8E4M3
-            self.gemm = TestFP8_GEMM(
-                precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors)
+            self.gemm = FP8GemmModule(
+                precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors,
+                hidden_size=hidden_size, out_features=out_features)
         def forward(self, inp, weight):
             x = self.gemm(inp, weight)
@@ -1292,60 +1296,6 @@ def test_export_gemm_layernorm(
                 tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
             return x
-    out_features = 128
-    hidden_size = 128
-    in_features = 128
-
-    class TestFP8_GEMM(nn.Module):
-        def __init__(self, precision, use_bias, gelu, scale_factors):
-            super().__init__()
-            self.use_bias = use_bias
-            self.gelu = gelu
-            self.precision = precision
-            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
-            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
-            nb_inp_scales, nb_weight_scales = 1, out_features
-            act_scale_factor, weight_scale_factor = scale_factors
-            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
-            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
-            bias_size = nb_weight_scales
-            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
-            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
-            self.inp_type = tex.DType.kFloat8E4M3
-            self.weights_type = tex.DType.kFloat8E4M3
-            self.outp_type = precision
-
-        def forward(self, inp, weight):
-            inp_fp8 = cast_to_fp8(
-                inp,
-                self.meta_inp,
-                self.fp8_tensor_inp,
-                self.inp_type)
-            weight_fp8 = cast_to_fp8(
-                weight,
-                self.meta_weight,
-                self.fp8_tensor_weight,
-                self.weights_type)
-            ret, _ = fp8_gemm(
-                weight_fp8,
-                self.meta_weight.scale_inv,
-                self.fp8_tensor_weight,
-                self.inp_type,
-                inp_fp8,
-                self.meta_inp.scale_inv,
-                self.fp8_tensor_inp,
-                self.weights_type,
-                self.outp_type,
-                get_workspace(),
-                bias=self.bias,
-                use_bias=self.use_bias,
-                use_split_accumulator=False)
-            return ret
-
     inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
     weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
     model = TestFP8_GemmLayernorm()