Unverified commit 793a54bf authored by rybakov, committed by GitHub

remove code duplication in a test (#915)


Signed-off-by: Oleg Rybakov <orybakov@nvidia.com>
Co-authored-by: Oleg Rybakov <orybakov@nvidia.com>
parent 43678153
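Before the diff itself, here is a minimal, self-contained sketch of the deduplication pattern this commit applies: a helper nn.Module that was previously redefined inside each test is hoisted to module scope and takes the test-specific shapes as constructor arguments. SharedGemmHelper, test_small, and test_large are hypothetical names used only for illustration; the real helper introduced by the commit is FP8GemmModule, shown in the diff below.

# Hypothetical sketch of the "hoist the shared test helper" pattern; not the actual test code.
import torch
import torch.nn as nn


class SharedGemmHelper(nn.Module):
    # Defined once at module scope; the shapes are constructor arguments instead of
    # names captured from an enclosing test function.
    def __init__(self, hidden_size, out_features):
        super().__init__()
        self.bias = torch.randn(out_features)
        self.gelu_input = torch.randn(hidden_size, out_features)

    def forward(self, inp, weight):
        # Plain GEMM stand-in for the FP8 GEMM used by the real helper.
        return inp @ weight.t() + self.bias


def test_small():
    # Each test now constructs the shared helper instead of defining its own copy.
    model = SharedGemmHelper(hidden_size=128, out_features=128)
    out = model(torch.randn(128, 64), torch.randn(128, 64))
    assert out.shape == (128, 128)


def test_large():
    model = SharedGemmHelper(hidden_size=128, out_features=256)
    out = model(torch.randn(128, 64), torch.randn(256, 64))
    assert out.shape == (128, 256)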
@@ -328,6 +328,57 @@ def get_attn_mask_str(use_mask, attn_mask_type):
     return attn_mask_str
+class FP8GemmModule(nn.Module):
+    def __init__(self, precision, use_bias, gelu, scale_factors, hidden_size, out_features):
+        super().__init__()
+        self.use_bias = use_bias
+        self.gelu = gelu
+        self.precision = precision
+        self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
+        self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
+        nb_inp_scales, nb_weight_scales = 1, out_features
+        act_scale_factor, weight_scale_factor = scale_factors
+        self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
+        self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
+        bias_size = nb_weight_scales
+        self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
+        self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
+        self.inp_type = tex.DType.kFloat8E4M3
+        self.weights_type = tex.DType.kFloat8E4M3
+        self.outp_type = precision
+    def forward(self, inp, weight):
+        inp_fp8 = cast_to_fp8(
+            inp,
+            self.meta_inp,
+            self.fp8_tensor_inp,
+            self.inp_type)
+        weight_fp8 = cast_to_fp8(
+            weight,
+            self.meta_weight,
+            self.fp8_tensor_weight,
+            self.weights_type)
+        ret, _ = fp8_gemm(
+            weight_fp8,
+            self.meta_weight.scale_inv,
+            self.fp8_tensor_weight,
+            self.inp_type,
+            inp_fp8,
+            self.meta_inp.scale_inv,
+            self.fp8_tensor_inp,
+            self.weights_type,
+            self.outp_type,
+            get_workspace(),
+            bias=self.bias,
+            use_bias=self.use_bias,
+            use_split_accumulator=False)
+        return ret
 """
 Tests cases begin here.
 """
@@ -477,57 +528,6 @@ def test_export_gemm(
     if use_fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    class TestFP8_GEMM(nn.Module):
-        def __init__(self, precision, use_bias, gelu, scale_factors):
-            super().__init__()
-            self.use_bias = use_bias
-            self.gelu = gelu
-            self.precision = precision
-            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
-            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
-            nb_inp_scales, nb_weight_scales = 1, out_features
-            act_scale_factor, weight_scale_factor = scale_factors
-            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
-            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
-            bias_size = nb_weight_scales
-            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
-            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
-            self.inp_type = tex.DType.kFloat8E4M3
-            self.weights_type = tex.DType.kFloat8E4M3
-            self.outp_type = precision
-        def forward(self, inp, weight):
-            inp_fp8 = cast_to_fp8(
-                inp,
-                self.meta_inp,
-                self.fp8_tensor_inp,
-                self.inp_type)
-            weight_fp8 = cast_to_fp8(
-                weight,
-                self.meta_weight,
-                self.fp8_tensor_weight,
-                self.weights_type)
-            ret, _ = fp8_gemm(
-                weight_fp8,
-                self.meta_weight.scale_inv,
-                self.fp8_tensor_weight,
-                self.inp_type,
-                inp_fp8,
-                self.meta_inp.scale_inv,
-                self.fp8_tensor_inp,
-                self.weights_type,
-                self.outp_type,
-                get_workspace(),
-                bias=self.bias,
-                use_bias=self.use_bias,
-                use_split_accumulator=False)
-            return ret
     class Test_GEMM(nn.Module):
         def __init__(self, precision, use_bias=False, gelu=False):
             super().__init__()
@@ -576,7 +576,7 @@ def test_export_gemm(
     fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
     input_names = ['input', 'weight']
     if use_fp8:
-        model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
+        model = FP8GemmModule(precision, use_bias, use_gelu, scale_factors, hidden_size, out_features)
         do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
         te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
         serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
@@ -1253,6 +1253,9 @@ def test_export_gemm_layernorm(
     The test sets up GEMM with FP32 output which feeds into an LN that is configured
     with FP16 or BF16 weights and bias.
     """
+    out_features = 128
+    hidden_size = 128
+    in_features = 128
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and not fp8_available:
@@ -1268,8 +1271,9 @@ def test_export_gemm_layernorm(
             self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
             self.meta = create_meta(ln_scale_factor)
             self.fp8_type = tex.DType.kFloat8E4M3
-            self.gemm = TestFP8_GEMM(
-                precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors)
+            self.gemm = FP8GemmModule(
+                precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors,
+                hidden_size=hidden_size, out_features=out_features)
         def forward(self, inp, weight):
             x = self.gemm(inp, weight)
@@ -1292,60 +1296,6 @@ def test_export_gemm_layernorm(
                 tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
             return x
-    out_features = 128
-    hidden_size = 128
-    in_features = 128
-    class TestFP8_GEMM(nn.Module):
-        def __init__(self, precision, use_bias, gelu, scale_factors):
-            super().__init__()
-            self.use_bias = use_bias
-            self.gelu = gelu
-            self.precision = precision
-            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
-            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
-            nb_inp_scales, nb_weight_scales = 1, out_features
-            act_scale_factor, weight_scale_factor = scale_factors
-            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
-            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
-            bias_size = nb_weight_scales
-            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
-            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
-            self.inp_type = tex.DType.kFloat8E4M3
-            self.weights_type = tex.DType.kFloat8E4M3
-            self.outp_type = precision
-        def forward(self, inp, weight):
-            inp_fp8 = cast_to_fp8(
-                inp,
-                self.meta_inp,
-                self.fp8_tensor_inp,
-                self.inp_type)
-            weight_fp8 = cast_to_fp8(
-                weight,
-                self.meta_weight,
-                self.fp8_tensor_weight,
-                self.weights_type)
-            ret, _ = fp8_gemm(
-                weight_fp8,
-                self.meta_weight.scale_inv,
-                self.fp8_tensor_weight,
-                self.inp_type,
-                inp_fp8,
-                self.meta_inp.scale_inv,
-                self.fp8_tensor_inp,
-                self.weights_type,
-                self.outp_type,
-                get_workspace(),
-                bias=self.bias,
-                use_bias=self.use_bias,
-                use_split_accumulator=False)
-            return ret
     inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
     weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
     model = TestFP8_GemmLayernorm()
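As a side note on the test_export_gemm_layernorm docstring above ("GEMM with FP32 output which feeds into an LN that is configured with FP16 or BF16 weights and bias"), the following is a hedged, plain-PyTorch sketch of that structure; gemm_then_layernorm is a hypothetical name, and the real test uses Transformer Engine's FP8 GEMM and layernorm kernels rather than these torch ops.

# Hypothetical sketch of the GEMM -> LayerNorm structure exercised by the test.
import torch

def gemm_then_layernorm(inp, weight, ln_weight, ln_bias, eps=1e-5):
    # GEMM computed in FP32 ...
    x = inp.float() @ weight.float().t()
    # ... feeding a LayerNorm whose affine parameters are stored in FP16/BF16.
    x = torch.nn.functional.layer_norm(
        x, normalized_shape=(x.shape[-1],),
        weight=ln_weight.float(), bias=ln_bias.float(), eps=eps)
    # Return the result in the LayerNorm parameters' (reduced) precision.
    return x.to(ln_weight.dtype)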