"vscode:/vscode.git/clone" did not exist on "293c7906e10c4baa41616b6a7b3d3b88ddad6e53"
Unverified commit 3bcd7f6f, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] Add tests for current scaling; misc related fixes (#1606)



* Cleanup sanity tests and add CS recipe tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix sanity test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CG capture with CS recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ops for CG
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0356010c
@@ -54,6 +54,7 @@ model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
 fp8_recipes = [
     recipe.DelayedScaling(),
     recipe.MXFP8BlockScaling(),
+    recipe.Float8CurrentScaling(),
 ]
 # Supported data types
@@ -103,32 +103,17 @@ model_configs = {
 }
 fp8_recipes = [
-    None,  # Handles non-FP8 case
-    recipe.MXFP8BlockScaling(),
-    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
-    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
+    None,  # Test non-FP8
+    recipe.MXFP8BlockScaling(),  # Test default
+    recipe.Float8CurrentScaling(),  # Test default
+    recipe.DelayedScaling(),  # Test default
+    recipe.DelayedScaling(  # Test most_recent algo
         amax_history_len=16,
         amax_compute_algo="most_recent",
     ),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
-        amax_compute_algo="max",
-    ),
-    recipe.DelayedScaling(
-        margin=0,
+    recipe.DelayedScaling(  # Test custom amax and scale compute algo
         fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
         amax_compute_algo=custom_amax_compute,
-    ),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
         scaling_factor_compute_algo=custom_amax_to_scale,
     ),
 ]
@@ -560,6 +545,8 @@ def test_sanity_grouped_linear(
         pytest.skip(reason_for_no_fp8)
     if fp8_recipe.mxfp8():
         pytest.skip("Grouped linear does not support MXFP8")
+    if fp8_recipe.float8_current_scaling():
+        pytest.skip("Grouped linear does not support FP8 current scaling")
     if not config.is_fp8_supported():
         pytest.skip("Model config does not support FP8")
@@ -197,8 +197,9 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
       max_fp8 = Quantized_Limits<DType>::max_norm;);
   // Update scale
-  compute_scale_from_amax_kernel<<<1, 1>>>(reinterpret_cast<const float *>(output.amax.dptr),
-                                           reinterpret_cast<float *>(output.scale.dptr), max_fp8,
-                                           config.force_pow_2_scales, config.amax_epsilon);
+  compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
+      reinterpret_cast<const float *>(output.amax.dptr),
+      reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
+      config.amax_epsilon);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -283,7 +283,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         recipe_state = fp8_meta[fp8_meta_key]
         # Reallocate amax history if needed
-        if recipe.mxfp8():
+        if not recipe.delayed():
             continue
         current_length = recipe_state.amax_history.size(0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment