Commit b11d6fca authored by tabuchixiangcai3
Browse files

[DCU]Fix the original code


Signed-off-by: Tangao <2205747538@qq.com>
parent 2a64c9a6
......@@ -16,6 +16,7 @@ import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from test_numerics import (
......@@ -45,6 +46,7 @@ FEATURE_DIRS = None
all_boolean = [True, False]
TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None:
......@@ -231,7 +233,7 @@ def run_debug_test(func):
return wrapper
CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
CONFIG_LOG_TEST_DISTRIBUTED_FP8 = """log_distributed:
layers:
layer_types: [linear]
enabled:
......@@ -251,11 +253,27 @@ CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
end_step: 1
"""
# Debug-logging config selected when FP8 is not available: it requests only
# the LogTensorStats feature for layers of type `linear`, for steps 0..1.
# NOTE(review): the YAML nesting below was reconstructed from a copy whose
# leading whitespace had been stripped (which is not valid YAML) -- confirm
# the indentation against the repository version of this file.
CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8 = """log_distributed:
  layers:
    layer_types: [linear]
  enabled:
    True
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors: [activation, gradient, weight, output, wgrad, dgrad]
      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
      start_step : 0
      end_step: 1
"""
def _prepare_config_test_log_distributed(config_file):
    """Write the distributed-logging debug config to ``config_file``.

    Only rank 0 writes; every other rank returns without touching the file.
    The FP8 variant of the config is written when ``fp8_available`` is true,
    otherwise the non-FP8 variant is written.

    Args:
        config_file: File-like object exposing ``write`` and ``flush``.
    """
    if WORLD_RANK != 0:
        return
    # Exactly one config is written. The earlier unconditional write of the
    # pre-rename ``CONFIG_LOG_TEST_DISTRIBUTED`` constant is intentionally
    # gone: keeping it would emit a second config into the same file.
    config = (
        CONFIG_LOG_TEST_DISTRIBUTED_FP8 if fp8_available else CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8
    )
    config_file.write(config)
    config_file.flush()
......@@ -355,6 +373,39 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
)
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
    """Run one forward/backward pass with the quantized-stats logging config.

    This is a sanity run only: the function itself makes no assertions about
    the logged statistics. The ranks are split into two data-parallel replicas
    of ``WORLD_SIZE // 2`` tensor-parallel ranks each; the run is skipped when
    ``WORLD_SIZE`` is odd.

    Args:
        parallel_mode: Forwarded to ``_get_tensors``, ``_init_model`` and
            ``_run_forward_backward``.
        gather_weight: Value passed to ``set_weight_tensor_tp_group_reduce``
            for the duration of the run.
        **kwargs: Must contain ``config_file`` (object with ``write``,
            ``flush`` and ``name``) and ``log_dir``.
    """
    from test_log import LOG_QUANTIZED_CONFIG

    config_file = kwargs["config_file"]
    config_file.write(LOG_QUANTIZED_CONFIG)
    config_file.flush()
    _init_debug(config_file.name, kwargs["log_dir"], FEATURE_DIRS)

    set_weight_tensor_tp_group_reduce(gather_weight)
    try:
        dp_size = 2
        if WORLD_SIZE % dp_size != 0:
            return  # skip

        tp_size = WORLD_SIZE // dp_size
        tp_rank = WORLD_RANK % tp_size
        dp_rank = (WORLD_RANK - tp_rank) // tp_size

        debug_api.set_tensor_reduction_group(NCCL_WORLD)

        # Weights are seeded per TP rank and inputs per DP rank.
        x, weight = _get_tensors(
            parallel_mode,
            weight_seed=tp_rank * 1234,
            data_seed=dp_rank * 1234,
            tp_size=tp_size,
            tp_rank=tp_rank,
        )

        # Consecutive ranks [dp_rank * tp_size, (dp_rank + 1) * tp_size) form
        # this rank's tensor-parallel group.
        tp_group_ranks = list(range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
        tp_group = dist.new_group(ranks=tp_group_ranks)

        model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
        _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
    finally:
        # Always restore the flag, including on the skip path and on errors,
        # so the setting made above does not leak into later tests.
        set_weight_tensor_tp_group_reduce(True)  # reset
@run_debug_test
def test_log_expert_parallel(**kwargs):
......@@ -371,13 +422,13 @@ def test_log_expert_parallel(**kwargs):
) # data parallel
model = _init_model(weight, parallel_mode=None, name="linear1")
model1 = _init_model(weight, parallel_mode=None, name="linear2")
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
y1 = model(x)
y2 = model1(x)
y = y1 + y2
y.sum().backward()
debug_api.step()
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
y = model(x)
if WORLD_RANK != 0:
y = y + model1(x)
......@@ -637,9 +688,11 @@ if __name__ == "__main__":
for gather_weight in [True, False]:
test_log_distributed(parallel_mode, gather_weight)
if fp8_available:
for parallel_mode in ["row", "column"]:
test_disable_fp8_layer(parallel_mode)
if IS_HIP_EXTENSION:
# Output type 32 (FP32) does not support int8 simulation.
pass
......
......@@ -517,6 +517,7 @@ def test_linear():
{"return_bias": True},
{"params_dtype": torch.float16},
{"delay_wgrad_compute": True},
{"save_original_input": True},
]
# TODO: The blockwise recipe does not currently support calculations with bias set to True.
"""
......@@ -528,6 +529,8 @@ def test_linear():
else:
kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
for parallel_mode in ["column", "row"]:
for sequence_parallel in [False, True]:
_test_linear(parallel_mode, sequence_parallel, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment