Unverified Commit 3bcd7f6f authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Add tests for current scaling; misc related fixes (#1606)



* Cleanup sanity tests and add CS recipe tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix sanity test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CG capture with CS recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ops for CG
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0356010c
@@ -54,6 +54,7 @@ model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
 fp8_recipes = [
     recipe.DelayedScaling(),
     recipe.MXFP8BlockScaling(),
+    recipe.Float8CurrentScaling(),
 ]
 # Supported data types
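For context, the list above parametrizes the sanity suite, and each entry is consumed through `fp8_autocast`. A minimal sketch of that usage, assuming a CUDA device with FP8 support (the module and shapes here are illustrative, not the test file's code):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Run one forward/backward pass under the newly added recipe.
fp8_recipe = recipe.Float8CurrentScaling()
linear = te.Linear(64, 64).cuda()
x = torch.randn(32, 64, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = linear(x)
y.sum().backward()
```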
@@ -103,32 +103,17 @@ model_configs = {
 }
 fp8_recipes = [
-    None,  # Handles non-FP8 case
-    recipe.MXFP8BlockScaling(),
-    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
-    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
+    None,  # Test non-FP8
+    recipe.MXFP8BlockScaling(),  # Test default
+    recipe.Float8CurrentScaling(),  # Test default
+    recipe.DelayedScaling(),  # Test default
+    recipe.DelayedScaling(  # Test most_recent algo
         amax_history_len=16,
         amax_compute_algo="most_recent",
     ),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
-        amax_compute_algo="max",
-    ),
-    recipe.DelayedScaling(
-        margin=0,
+    recipe.DelayedScaling(  # Test custom amax and scale compute algo
         fp8_format=recipe.Format.E4M3,
         amax_history_len=16,
         amax_compute_algo=custom_amax_compute,
-    ),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
         scaling_factor_compute_algo=custom_amax_to_scale,
     ),
 ]
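`custom_amax_compute` and `custom_amax_to_scale` are helpers defined elsewhere in the test file. A plausible sketch of their shapes, assuming the callback conventions of `DelayedScaling` (an amax-history reduction, and an amax-to-scale function); the signatures and bodies here are assumptions for illustration:

```python
import torch

def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    # Reduce the rolling amax history window to one amax per tensor.
    return torch.max(amax_history, dim=0).values

def custom_amax_to_scale(amax, scale, fp8_max, recipe_args):
    # scale = fp8_max / amax, keeping the previous scale when amax is
    # zero or non-finite so the scaling factor stays valid.
    sf = fp8_max / amax
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(sf), sf, scale)
    return sf
```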
@@ -560,6 +545,8 @@ def test_sanity_grouped_linear(
         pytest.skip(reason_for_no_fp8)
     if fp8_recipe.mxfp8():
         pytest.skip("Grouped linear does not support MXFP8")
+    if fp8_recipe.float8_current_scaling():
+        pytest.skip("Grouped linear does not support FP8 current scaling")
     if not config.is_fp8_supported():
         pytest.skip("Model config does not support FP8")
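These guards rely on kind predicates exposed by the recipe objects (`mxfp8()`, `float8_current_scaling()`, and `delayed()`, all of which appear in this commit's diffs). A quick illustration of how they discriminate recipe kinds:

```python
from transformer_engine.common import recipe

# Each concrete recipe answers True for exactly one kind predicate.
r = recipe.Float8CurrentScaling()
assert r.float8_current_scaling()
assert not r.mxfp8() and not r.delayed()
```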
@@ -197,8 +197,9 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
                              max_fp8 = Quantized_Limits<DType>::max_norm;);
   // Update scale
-  compute_scale_from_amax_kernel<<<1, 1>>>(reinterpret_cast<const float *>(output.amax.dptr),
-                                           reinterpret_cast<float *>(output.scale.dptr), max_fp8,
-                                           config.force_pow_2_scales, config.amax_epsilon);
+  compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
+      reinterpret_cast<const float *>(output.amax.dptr),
+      reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
+      config.amax_epsilon);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
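The bug fixed here is that the kernel was launched on the legacy default stream instead of the caller's `stream`. That matters for the "Fix CG capture" item in the commit message: during CUDA graph capture, every kernel must be enqueued on the capturing stream, and a launch on the legacy default stream fails capture. A minimal PyTorch-level sketch of the capture pattern this fix enables (the tensor work is illustrative):

```python
import torch

g = torch.cuda.CUDAGraph()
side = torch.cuda.Stream()
x = torch.ones(4, device="cuda")

# Standard warmup on a side stream before capture.
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    y = x * 2.0
torch.cuda.current_stream().wait_stream(side)

# During capture, any kernel hard-coded to the legacy default stream
# would abort capture; launches must follow the capture stream.
with torch.cuda.graph(g):
    y = x * 2.0
g.replay()
```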
@@ -283,7 +283,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             recipe_state = fp8_meta[fp8_meta_key]
             # Reallocate amax history if needed
-            if recipe.mxfp8():
+            if not recipe.delayed():
                 continue
             current_length = recipe_state.amax_history.size(0)
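Only delayed scaling keeps an amax history, so the reallocation step must be skipped for every non-delayed recipe. The old `mxfp8()` check let the new `Float8CurrentScaling` recipe fall through into delayed-scaling-only state handling; inverting the test on `delayed()` also covers any future non-delayed recipe. A small check of the corrected predicate, runnable without a GPU:

```python
from transformer_engine.common import recipe

for r in (recipe.DelayedScaling(),
          recipe.MXFP8BlockScaling(),
          recipe.Float8CurrentScaling()):
    # Only DelayedScaling carries an amax history to reallocate.
    print(type(r).__name__, "has amax history:", r.delayed())
```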