Commit b11d6fca authored by tabuchixiangcai3

[DCU] Fix the original code


Signed-off-by: Tangao <2205747538@qq.com>
parent 2a64c9a6
@@ -16,6 +16,7 @@ import transformer_engine
 import transformer_engine_torch as tex
 import nvdlfw_inspect.api as debug_api
 from transformer_engine.debug import set_weight_tensor_tp_group_reduce
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from test_numerics import (
@@ -45,6 +46,7 @@ FEATURE_DIRS = None
 all_boolean = [True, False]
 TEST_NR = 0
+fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
 def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
     if tp_size is None:
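The hunk above computes fp8_available once at import time and reuses it throughout the file. A minimal sketch of the same gating pattern in a standalone pytest module (the test name is illustrative, not part of this commit):

```python
# Hedged sketch: skip FP8-dependent tests on hardware without FP8 support.
# FP8GlobalStateManager.is_fp8_available() returns (bool, reason_string).
import pytest
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_requires_fp8():
    ...  # body runs only where FP8 kernels are supported
```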
@@ -231,7 +233,7 @@ def run_debug_test(func):
     return wrapper
-CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
+CONFIG_LOG_TEST_DISTRIBUTED_FP8 = """log_distributed:
   layers:
     layer_types: [linear]
   enabled:
@@ -251,11 +253,27 @@ CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
       end_step: 1
 """
+CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8 = """log_distributed:
+  layers:
+    layer_types: [linear]
+  enabled:
+    True
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      tensors: [activation, gradient, weight, output, wgrad, dgrad]
+      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
+      start_step : 0
+      end_step: 1
+"""
 def _prepare_config_test_log_distributed(config_file):
     if WORLD_RANK != 0:
         return
-    config_file.write(CONFIG_LOG_TEST_DISTRIBUTED)
+    config_file.write(
+        CONFIG_LOG_TEST_DISTRIBUTED_FP8 if fp8_available else CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8
+    )
     config_file.flush()
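With two config variants defined, rank 0 now writes whichever one matches the hardware, and every rank reads the same file. A self-contained sketch of that write-then-flush pattern; the helper name and stand-in config strings are illustrative:

```python
# Hedged sketch of the rank-0 config-selection pattern used above.
CONFIG_FP8 = "log_distributed:\n  enabled:\n    True\n"     # stand-in contents
CONFIG_NO_FP8 = "log_distributed:\n  enabled:\n    True\n"  # stand-in contents

def prepare_config(config_file, world_rank: int, fp8_available: bool) -> None:
    if world_rank != 0:
        return  # single writer; the other ranks reuse the same file path
    config_file.write(CONFIG_FP8 if fp8_available else CONFIG_NO_FP8)
    config_file.flush()  # flush so other ranks see the contents immediately
```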
@@ -355,6 +373,39 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
     )
     set_weight_tensor_tp_group_reduce(True) # reset
+@run_debug_test
+def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
+    from test_log import LOG_QUANTIZED_CONFIG
+
+    kwargs["config_file"].write(LOG_QUANTIZED_CONFIG)
+    kwargs["config_file"].flush()
+    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
+    set_weight_tensor_tp_group_reduce(gather_weight)
+    if WORLD_SIZE % 2 != 0:
+        return  # skip
+    TP_SIZE = WORLD_SIZE // 2
+    DP_SIZE = 2
+    TP_RANK = WORLD_RANK % TP_SIZE
+    DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
+    debug_api.set_tensor_reduction_group(NCCL_WORLD)
+    x, weight = _get_tensors(
+        parallel_mode,
+        weight_seed=TP_RANK * 1234,
+        data_seed=DP_RANK * 1234,
+        tp_size=TP_SIZE,
+        tp_rank=TP_RANK,
+    )
+    tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
+    tp_group = dist.new_group(ranks=tp_group_ranks)
+    model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
+    _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
+    set_weight_tensor_tp_group_reduce(True)  # reset
+
 @run_debug_test
 def test_log_expert_parallel(**kwargs):
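The new sanity test splits WORLD_SIZE ranks into DP_SIZE = 2 data-parallel replicas, each holding TP_SIZE tensor-parallel ranks. A worked example of that rank arithmetic for four ranks:

```python
# Rank decomposition used by sanity_test_log_quantized_stats, shown for 4 ranks.
WORLD_SIZE = 4
TP_SIZE = WORLD_SIZE // 2  # 2 tensor-parallel ranks per data-parallel replica
for world_rank in range(WORLD_SIZE):
    tp_rank = world_rank % TP_SIZE
    dp_rank = (world_rank - tp_rank) // TP_SIZE
    tp_group = list(range(dp_rank * TP_SIZE, (dp_rank + 1) * TP_SIZE))
    print(f"rank {world_rank}: tp_rank={tp_rank} dp_rank={dp_rank} tp_group={tp_group}")
# rank 0: tp_rank=0 dp_rank=0 tp_group=[0, 1]
# rank 1: tp_rank=1 dp_rank=0 tp_group=[0, 1]
# rank 2: tp_rank=0 dp_rank=1 tp_group=[2, 3]
# rank 3: tp_rank=1 dp_rank=1 tp_group=[2, 3]
```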
@@ -371,13 +422,13 @@ def test_log_expert_parallel(**kwargs):
     ) # data parallel
     model = _init_model(weight, parallel_mode=None, name="linear1")
     model1 = _init_model(weight, parallel_mode=None, name="linear2")
-    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+    with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
         y1 = model(x)
         y2 = model1(x)
         y = y1 + y2
         y.sum().backward()
     debug_api.step()
-    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+    with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
         y = model(x)
         if WORLD_RANK != 0:
             y = y + model1(x)
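Replacing enabled=True with enabled=fp8_available lets the same forward/backward pass run on hardware without FP8 by falling back to the modules' regular dtype. A minimal standalone sketch (layer sizes and the default recipe are illustrative; assumes a CUDA or ROCm device):

```python
# Hedged sketch of the enabled=fp8_available pattern outside the test harness.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
model = te.Linear(64, 64)
x = torch.randn(32, 64, device="cuda")
with te.fp8_autocast(enabled=fp8_available):  # runs in regular precision when FP8 is absent
    y = model(x)
y.sum().backward()
```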
@@ -637,9 +688,11 @@ if __name__ == "__main__":
         for gather_weight in [True, False]:
             test_log_distributed(parallel_mode, gather_weight)
-    for parallel_mode in ["row", "column"]:
-        test_disable_fp8_layer(parallel_mode)
+    if fp8_available:
+        for parallel_mode in ["row", "column"]:
+            test_disable_fp8_layer(parallel_mode)
     if IS_HIP_EXTENSION:
         # Output type 32 (FP32) does not support int8 simulation.
         pass
...
@@ -517,6 +517,7 @@ def test_linear():
         {"return_bias": True},
         {"params_dtype": torch.float16},
         {"delay_wgrad_compute": True},
+        {"save_original_input": True},
     ]
     #TODO:The blockwise recipe does not currently support calculations with bias set to true.
     """
@@ -528,6 +529,8 @@ def test_linear():
     else:
         kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
+        if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
+            continue
         for parallel_mode in ["column", "row"]:
             for sequence_parallel in [False, True]:
                 _test_linear(parallel_mode, sequence_parallel, **kwargs)
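The sweep now drops the save_original_input case when the active quantization recipe is FP8, which this environment appears not to support for that flag. The guard generalizes to any incompatible kwargs/recipe pairing; a sketch with stand-in names (QUANTIZATION and run_case are not from the real test file):

```python
# Hedged sketch: filtering unsupported kwarg/recipe combinations before a sweep.
QUANTIZATION = "fp8"  # stand-in for the module-level setting in the test file

def is_supported(kwargs: dict, quantization: str) -> bool:
    # Mirrors the guard above: save_original_input is skipped under FP8.
    return not (kwargs.get("save_original_input", False) and quantization == "fp8")

def run_case(**kwargs) -> None:
    print("running with", kwargs)  # stand-in for _test_linear(...)

kwargs_list = [{"delay_wgrad_compute": True}, {"save_original_input": True}]
for kwargs in kwargs_list:
    if not is_supported(kwargs, QUANTIZATION):
        continue
    run_case(**kwargs)
```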
...