Commit 3b081313 authored by dongcl

add description for quantized communication

parent 9b681ba5
@@ -77,6 +77,12 @@ def unpermute(
```
+ The project supports splitting dw via split-bw to achieve better overlap. Current test results show that enabling split-bw performs poorly; further optimization is pending.
### Quantized communication support
+ The project supports quantized communication, which encodes the all-to-all communication data in a low-precision representation to reduce communication volume. To enable this feature, add the following argument to the launch script:
```
--use-quantize-comm
```
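For intuition, the snippet below sketches why a low-precision representation shrinks the all-to-all payload: mapping bf16/fp32 activations to int8 means each element costs one byte on the wire instead of two or four. It is illustrative only and does not reproduce the project's `q_alltoall`; the symmetric per-tensor scheme is an assumption.
```python
import torch

def int8_roundtrip(x: torch.Tensor) -> torch.Tensor:
    """Illustrative only: symmetric per-tensor int8 quantize/dequantize.
    The project's actual q_alltoall may use a different scheme."""
    scale = x.abs().max().clamp(min=1e-8) / 127.0                # single fp32 scale for the whole tensor
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)   # 1-byte payload sent over all-to-all
    return q.to(x.dtype) * scale                                 # dequantized on the receiving side

x = torch.randn(4, 8)
print((x - int8_roundtrip(x)).abs().max())  # small error; payload is 4x smaller than fp32, 2x smaller than bf16
```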
## Usage
### Downloading the project
@@ -97,6 +103,7 @@ def unpermute(
2.3 Extract the offline Megatron-LM code package into the Megatron-LM directory under dcu_megatron
### Using the project
To run the models, go to the examples directory, which contains the execution scripts for the relevant models; please download the required datasets yourself: https://r0ddbu55vzx.feishu.cn/drive/folder/ZxHHfCoX4lg75td2hTqcmiAin3g
```
@@ -5,12 +5,12 @@ from .qcomm import q_alltoall
class _AllToAll(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
+    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_quantize_comm=False):
        """Forward function."""
        ctx.group = group
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
-        ctx.use_qcomm = use_qcomm
+        ctx.use_quantize_comm = use_quantize_comm
        world_size = torch.distributed.get_world_size(group=group)
        # Bypass the function if we are using only 1 GPU.
@@ -20,7 +20,7 @@ class _AllToAll(torch.autograd.Function):
        input = input.contiguous()
        if output_split_sizes is None:
            # Equal split (all2all)
-            if use_qcomm:
+            if use_quantize_comm:
                output = input.new_empty(
                    size=[input.shape[0], input.shape[1]+4],
                    dtype=torch.int8,
@@ -30,7 +30,7 @@ class _AllToAll(torch.autograd.Function):
                output = torch.empty_like(input)
        else:
            # Unequal split (all2all-v)
-            if use_qcomm:
+            if use_quantize_comm:
                output = input.new_empty(
                    size=[sum(output_split_sizes)] + list(input.size()[1:]),
                    dtype=torch.int8,
@@ -43,7 +43,7 @@ class _AllToAll(torch.autograd.Function):
                    device=torch.cuda.current_device(),
                )
-        if use_qcomm:
+        if use_quantize_comm:
            output = q_alltoall(output, input, output_split_sizes, input_split_sizes, group)
        else:
            torch.distributed.all_to_all_single(
@@ -60,13 +60,13 @@ class _AllToAll(torch.autograd.Function):
        """Backward function."""
        return (
            None,
-            _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_qcomm),
+            _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_quantize_comm),
            None,
            None,
            None,
        )


-def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_qcomm=False):
+def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_quantize_comm=False):
    """Wrapper for autograd function"""
-    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_qcomm)
+    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_quantize_comm)
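In the equal-split path above, the int8 receive buffer gets `input.shape[1] + 4` columns, which hints that each quantized row may carry its own float32 scale in the trailing 4 bytes. The actual encoding lives in `qcomm.q_alltoall`, which is not shown in this diff, so the sketch below is only one plausible layout under that assumption (both helper names are hypothetical):
```python
import torch

def pack_rows_int8(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical layout: [rows, hidden] float -> [rows, hidden + 4] int8,
    with a per-row float32 scale reinterpreted as 4 trailing int8 bytes."""
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0   # [rows, 1] fp32
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)          # [rows, hidden] int8
    scale_bytes = scale.to(torch.float32).view(torch.int8)              # [rows, 4] int8
    return torch.cat([q, scale_bytes], dim=1)

def unpack_rows_int8(packed: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
    """Inverse of pack_rows_int8: recover [rows, hidden] in the requested dtype."""
    payload, scale_bytes = packed[:, :-4], packed[:, -4:]
    scale = scale_bytes.contiguous().view(torch.float32)                # [rows, 1] fp32
    return (payload.to(torch.float32) * scale).to(dtype)

x = torch.randn(6, 16)
packed = pack_rows_int8(x)                       # shape [6, 20] == [rows, hidden + 4]
restored = unpack_rows_int8(packed, dtype=torch.float32)
print(packed.shape, (x - restored).abs().max())
```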
@@ -40,9 +40,9 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
-        # use_qcomm
+        # use_quantize_comm
        args = get_args()
-        self.use_qcomm = args.use_qcomm
+        self.use_quantize_comm = args.use_quantize_comm

    def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
        state.num_global_tokens_per_local_expert = getattr(
@@ -134,7 +134,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
            "before_ep_alltoall", tokens_per_expert
        )
        global_input_tokens = all_to_all(
-            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
+            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_quantize_comm=self.use_quantize_comm
        )
        return tokens_per_expert, global_input_tokens
@@ -258,7 +258,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        # Perform expert parallel AlltoAll communication
        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
        permutated_local_input_tokens = all_to_all(
-            self.ep_group, hidden_states, self.input_splits, self.output_splits, use_qcomm=self.use_qcomm
+            self.ep_group, hidden_states, self.input_splits, self.output_splits, use_quantize_comm=self.use_quantize_comm
        )
        return permutated_local_input_tokens
@@ -129,7 +129,7 @@ def _add_extra_tokenizer_args(parser):
                                'NullTokenizer',
                                'DeepSeekV2Tokenizer'],
                       help='What type of tokenizer to use.')
-    group.add_argument('--use-qcomm',
+    group.add_argument('--use-quantize-comm',
                       default=False,
                       action="store_true",
                       help='use quantized communication')