Commit 3b081313 authored by dongcl

add description for quantized communication

parent 9b681ba5
@@ -77,6 +77,12 @@ def unpermute(
 ```
+The project supports splitting the weight-gradient computation (dW) via split-bw for better overlap. Current test results show that enabling split-bw performs poorly; further optimization is pending.
+### Quantized communication support
+The project supports quantized communication, representing all-to-all communication data in low precision to reduce communication volume (a sketch of one such packing scheme follows this diff). To enable this feature, add the following argument to the launch script:
+```
+--use-quantize-comm
+```
 ## Usage
 ### Project download
@@ -97,6 +103,7 @@ def unpermute(
 2.3 Extract the offline Megatron-LM code package into the Megatron-LM directory under dcu_megatron
 ### Project usage
 To use the project, go to the examples directory, which contains launch scripts for the relevant models; please download the required datasets yourself: https://r0ddbu55vzx.feishu.cn/drive/folder/ZxHHfCoX4lg75td2hTqcmiAin3g
 ```
......
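The hunk below sizes the int8 receive buffer as `input.shape[1] + 4`, which suggests each quantized row travels with a 4-byte fp32 scale packed after its int8 payload. The actual `q_alltoall` kernel is not part of this diff, so the following is only a minimal sketch of such a per-row packing scheme, assuming symmetric absmax scaling; `pack_int8_rows` and `unpack_int8_rows` are hypothetical names.

```python
import torch

def pack_int8_rows(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical packing: quantize each row to int8 and append the
    row's fp32 scale, reinterpreted as 4 extra int8 bytes."""
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    scale_bytes = scale.to(torch.float32).view(torch.int8)  # [rows, 4]
    return torch.cat([q, scale_bytes], dim=1)               # [rows, cols + 4]

def unpack_int8_rows(packed: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Inverse of pack_int8_rows: split payload from scale, dequantize."""
    q, scale_bytes = packed[:, :-4], packed[:, -4:].contiguous()
    scale = scale_bytes.view(torch.float32)                 # [rows, 1]
    return (q.to(torch.float32) * scale).to(dtype)

x = torch.randn(8, 16)
assert torch.allclose(x, unpack_int8_rows(pack_int8_rows(x)), atol=0.05)
```

Per-row scales bound the quantization error by each token's own dynamic range at a cost of only 4 bytes per row, while the int8 payload halves the wire size relative to bf16/fp16 activations.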
@@ -5,12 +5,12 @@ from .qcomm import q_alltoall
 class _AllToAll(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
+    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_quantize_comm=False):
         """Forward function."""
         ctx.group = group
         ctx.output_split_sizes = output_split_sizes
         ctx.input_split_sizes = input_split_sizes
-        ctx.use_qcomm = use_qcomm
+        ctx.use_quantize_comm = use_quantize_comm
         world_size = torch.distributed.get_world_size(group=group)
         # Bypass the function if we are using only 1 GPU.
@@ -20,7 +20,7 @@ class _AllToAll(torch.autograd.Function):
         input = input.contiguous()
         if output_split_sizes is None:
             # Equal split (all2all)
-            if use_qcomm:
+            if use_quantize_comm:
                 output = input.new_empty(
                     size=[input.shape[0], input.shape[1] + 4],
                     dtype=torch.int8,
@@ -30,7 +30,7 @@ class _AllToAll(torch.autograd.Function):
                 output = torch.empty_like(input)
         else:
             # Unequal split (all2all-v)
-            if use_qcomm:
+            if use_quantize_comm:
                 output = input.new_empty(
                     size=[sum(output_split_sizes)] + list(input.size()[1:]),
                     dtype=torch.int8,
@@ -43,7 +43,7 @@ class _AllToAll(torch.autograd.Function):
                     device=torch.cuda.current_device(),
                 )
-        if use_qcomm:
+        if use_quantize_comm:
            output = q_alltoall(output, input, output_split_sizes, input_split_sizes, group)
         else:
             torch.distributed.all_to_all_single(
@@ -60,13 +60,13 @@ class _AllToAll(torch.autograd.Function):
         """Backward function."""
         return (
             None,
-            _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_qcomm),
+            _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_quantize_comm),
             None,
             None,
             None,
         )
-def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_qcomm=False):
+def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_quantize_comm=False):
     """Wrapper for autograd function"""
-    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_qcomm)
+    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_quantize_comm)
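For reference, a single-process smoke test of the renamed wrapper, assuming `all_to_all` is imported from the module above (its path is not shown in this diff). With one rank the forward short-circuits before any quantization, per the bypass comment above, so this only exercises the autograd plumbing; note that `backward` re-applies `_AllToAll` with `input_split_sizes` and `output_split_sizes` swapped, so gradients retrace the forward routing in reverse.

```python
import os
import torch
import torch.distributed as dist
# from <module above> import all_to_all   # path not shown in this diff

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(8, 16, requires_grad=True)
y = all_to_all(dist.group.WORLD, x, use_quantize_comm=True)  # bypassed at world size 1
y.sum().backward()                                           # backward swaps the split lists
assert torch.equal(x.grad, torch.ones_like(x))
```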
@@ -40,9 +40,9 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # use_qcomm
+        # use_quantize_comm
         args = get_args()
-        self.use_qcomm = args.use_qcomm
+        self.use_quantize_comm = args.use_quantize_comm
     def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
         state.num_global_tokens_per_local_expert = getattr(
@@ -134,7 +134,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             "before_ep_alltoall", tokens_per_expert
         )
         global_input_tokens = all_to_all(
-            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
+            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_quantize_comm=self.use_quantize_comm
         )
         return tokens_per_expert, global_input_tokens
@@ -258,7 +258,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         # Perform expert parallel AlltoAll communication
         # hidden_states: [SEQL, H] -> [SEQL, H/TP]
         permutated_local_input_tokens = all_to_all(
-            self.ep_group, hidden_states, self.input_splits, self.output_splits, use_qcomm=self.use_qcomm
+            self.ep_group, hidden_states, self.input_splits, self.output_splits, use_quantize_comm=self.use_quantize_comm
         )
         return permutated_local_input_tokens
......
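Note the mirrored argument order at the two call sites above: the dispatch path passes `self.output_splits, self.input_splits` while the reverse path passes `self.input_splits, self.output_splits`, so what each rank sends in one direction is exactly what it receives in the other, and both legs go through the same quantized channel. A schematic of the round trip, with hypothetical names (`expert_forward` stands in for the real expert computation):

```python
def quantized_round_trip(ep_group, tokens, in_splits, out_splits, expert_forward):
    """Illustrative mirror of the two all_to_all call sites above."""
    # Leg 1: route local tokens out to their experts (quantized on the wire).
    dispatched = all_to_all(ep_group, tokens, out_splits, in_splits,
                            use_quantize_comm=True)
    expert_out = expert_forward(dispatched)
    # Leg 2: bring expert outputs home, with the split lists swapped.
    return all_to_all(ep_group, expert_out, in_splits, out_splits,
                      use_quantize_comm=True)
```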
@@ -129,7 +129,7 @@ def _add_extra_tokenizer_args(parser):
                                'NullTokenizer',
                                'DeepSeekV2Tokenizer'],
                        help='What type of tokenizer to use.')
-    group.add_argument('--use-qcomm',
+    group.add_argument('--use-quantize-comm',
                        default=False,
                        action="store_true",
                        help='use quantized communication')
......
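Finally, a standalone reproduction of the flag registration above with plain argparse (Megatron's grouped parser is elided here), mainly to show that the dashed flag surfaces as the underscored attribute `use_quantize_comm` that the dispatcher reads:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-quantize-comm',
                    default=False,
                    action='store_true',
                    help='use quantized communication')

args = parser.parse_args(['--use-quantize-comm'])
assert args.use_quantize_comm is True          # dashes map to underscores
assert parser.parse_args([]).use_quantize_comm is False
```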