evt_fugx1 / dcu_megatron · Commit 3b081313

add description for quantized communication

Authored Jun 09, 2025 by dongcl
Parent: 9b681ba5

Showing 4 changed files with 20 additions and 13 deletions:

  README.md                                               +7   -0
  dcu_megatron/core/tensor_parallel/mappings.py           +8   -8
  dcu_megatron/core/transformer/moe/token_dispatcher.py   +4   -4
  dcu_megatron/training/arguments.py                      +1   -1
README.md (+7, -0)

@@ -77,6 +77,12 @@ def unpermute(
 ```
 The project supports splitting dw via split-bw to achieve better overlap. Current test results show that enabling split-bw performs poorly; further optimization is pending.
+
+### Quantized communication support
+The project supports quantized communication: all-to-all communication data is represented at low precision to reduce communication volume. To use this feature, add the following argument to the launch script:
+```
+--use-quantize-comm
+```
 ## Usage
 ### Project download

@@ -97,6 +103,7 @@ def unpermute(
 2.
 3. Extract the offline Megatron-LM code package into the Megatron-LM directory under dcu_megatron
 ### Project usage
 To use the project, go to the examples directory, which contains launch scripts for the relevant models; please download the required datasets yourself: https://r0ddbu55vzx.feishu.cn/drive/folder/ZxHHfCoX4lg75td2hTqcmiAin3g
 ```
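For intuition about the bandwidth claim, here is a back-of-the-envelope sketch. It is not part of the commit and not the project's actual `qcomm` code; the "int8 payload plus one fp32 scale per row" layout is an assumption, though it is consistent with the `input.shape[1] + 4` receive buffer allocated in `mappings.py` below.

```python
# Hypothetical payload sizes for one token row (assumed layout: int8 values
# plus a single fp32 scale per row; the real qcomm scheme may differ).
hidden = 4096
row_bf16 = hidden * 2        # 8192 bytes per token row in bf16
row_int8 = hidden * 1 + 4    # 4100 bytes per token row as int8 + fp32 scale
print(row_int8 / row_bf16)   # ~0.5, i.e. roughly half the all-to-all volume
```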
dcu_megatron/core/tensor_parallel/mappings.py (+8, -8)

@@ -5,12 +5,12 @@ from .qcomm import q_alltoall
 class _AllToAll(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
+    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_quantize_comm=False):
         """Forward function."""
         ctx.group = group
         ctx.output_split_sizes = output_split_sizes
         ctx.input_split_sizes = input_split_sizes
-        ctx.use_qcomm = use_qcomm
+        ctx.use_quantize_comm = use_quantize_comm
         world_size = torch.distributed.get_world_size(group=group)
         # Bypass the function if we are using only 1 GPU.

@@ -20,7 +20,7 @@ class _AllToAll(torch.autograd.Function):
         input = input.contiguous()
         if output_split_sizes is None:
             # Equal split (all2all)
-            if use_qcomm:
+            if use_quantize_comm:
                 output = input.new_empty(
                     size=[input.shape[0], input.shape[1] + 4],
                     dtype=torch.int8,

@@ -30,7 +30,7 @@ class _AllToAll(torch.autograd.Function):
                 output = torch.empty_like(input)
         else:
             # Unequal split (all2all-v)
-            if use_qcomm:
+            if use_quantize_comm:
                 output = input.new_empty(
                     size=[sum(output_split_sizes)] + list(input.size()[1:]),
                     dtype=torch.int8,

@@ -43,7 +43,7 @@ class _AllToAll(torch.autograd.Function):
                 device=torch.cuda.current_device(),
             )
-        if use_qcomm:
+        if use_quantize_comm:
             output = q_alltoall(output, input, output_split_sizes, input_split_sizes, group)
         else:
             torch.distributed.all_to_all_single(

@@ -60,13 +60,13 @@ class _AllToAll(torch.autograd.Function):
         """Backward function."""
         return (
             None,
-            _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_qcomm),
+            _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_quantize_comm),
             None,
             None,
             None,
         )


-def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_qcomm=False):
+def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_quantize_comm=False):
     """Wrapper for autograd function"""
-    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_qcomm)
+    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_quantize_comm)
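`q_alltoall` itself is imported from `.qcomm` and is not shown in this commit, but the int8 receive buffer with `input.shape[1] + 4` columns suggests each quantized row carries 4 extra bytes, enough for one fp32 scale packed after the payload. Below is a minimal round-trip sketch of that assumed layout; `pack_int8_rows` and `unpack_int8_rows` are hypothetical illustrations, not the repository's API.

```python
import torch

def pack_int8_rows(x: torch.Tensor) -> torch.Tensor:
    """Assumed layout: [rows, hidden] float -> [rows, hidden + 4] int8,
    with each row's fp32 absmax scale reinterpreted as its last 4 bytes."""
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    scale_bytes = scale.to(torch.float32).view(torch.int8)   # [rows, 4]
    return torch.cat([q, scale_bytes], dim=1)                # [rows, hidden + 4]

def unpack_int8_rows(packed: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Inverse: split the int8 payload from the trailing scale bytes."""
    q, scale_bytes = packed[:, :-4], packed[:, -4:].contiguous()
    scale = scale_bytes.view(torch.float32)                  # [rows, 1]
    return (q.to(torch.float32) * scale).to(dtype)

x = torch.randn(8, 16)
packed = pack_int8_rows(x)
assert packed.shape == (8, 20) and packed.dtype == torch.int8
print((x - unpack_int8_rows(packed)).abs().max())  # small quantization error
```

Under this reading, the equal-split branch sizes the buffer as `[input.shape[0], input.shape[1] + 4]` so that the quantized payload and its per-row scale travel together through a single int8 all-to-all.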
dcu_megatron/core/transformer/moe/token_dispatcher.py (+4, -4)

@@ -40,9 +40,9 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # use_qcomm
+        # use_quantize_comm
         args = get_args()
-        self.use_qcomm = args.use_qcomm
+        self.use_quantize_comm = args.use_quantize_comm

     def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
         state.num_global_tokens_per_local_expert = getattr(

@@ -134,7 +134,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             "before_ep_alltoall", tokens_per_expert
         )
         global_input_tokens = all_to_all(
-            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
+            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_quantize_comm=self.use_quantize_comm
         )
         return tokens_per_expert, global_input_tokens

@@ -258,7 +258,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         # Perform expert parallel AlltoAll communication
         # hidden_states: [SEQL, H] -> [SEQL, H/TP]
         permutated_local_input_tokens = all_to_all(
-            self.ep_group, hidden_states, self.input_splits, self.output_splits, use_qcomm=self.use_qcomm
+            self.ep_group, hidden_states, self.input_splits, self.output_splits, use_quantize_comm=self.use_quantize_comm
         )
         return permutated_local_input_tokens
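For completeness, a sketch of calling the renamed wrapper directly, assuming a job launched with `torchrun`, an NCCL process group, and this repository on the import path; the dispatcher above makes the same call with its expert-parallel group and per-rank split sizes.

```python
import torch
import torch.distributed as dist
from dcu_megatron.core.tensor_parallel.mappings import all_to_all

# Assumes launch via torchrun so rank/world-size env vars are already set.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

tokens = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")

# Equal split (all2all); with use_quantize_comm=True the payload is routed
# through q_alltoall and sent as int8 instead of bf16.
out = all_to_all(dist.group.WORLD, tokens, None, None, use_quantize_comm=True)
```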
dcu_megatron/training/arguments.py (+1, -1)

@@ -129,7 +129,7 @@ def _add_extra_tokenizer_args(parser):
                                'NullTokenizer',
                                'DeepSeekV2Tokenizer'],
                        help='What type of tokenizer to use.')
-    group.add_argument('--use-qcomm',
+    group.add_argument('--use-quantize-comm',
                        default=False,
                        action="store_true",
                        help='use quantized communication')
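Since the option is a plain `store_true` switch, it defaults to off, and argparse converts dashes to underscores in the attribute name, which is why the dispatcher reads `args.use_quantize_comm`. A standalone sketch of the same pattern:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-quantize-comm',
                    default=False,
                    action="store_true",
                    help='use quantized communication')

print(parser.parse_args([]).use_quantize_comm)                       # False
print(parser.parse_args(['--use-quantize-comm']).use_quantize_comm)  # True
```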