evt_fugx1 / dcu_megatron · Commits

Commit dc8a93ae
Authored May 19, 2025 by dongcl

    megatron supports qcomm; 1f1b supports dense mlp

Parent: fa142de0

Showing 6 changed files with 334 additions and 6 deletions.
  dcu_megatron/core/models/gpt/fine_grained_schedule.py    +9    -0
  dcu_megatron/core/tensor_parallel/mappings.py             +72   -0
  dcu_megatron/core/tensor_parallel/qcomm.py                +215  -0
  dcu_megatron/core/transformer/moe/token_dispatcher.py     +12   -3
  dcu_megatron/core/transformer/transformer_layer.py        +22   -3
  dcu_megatron/training/arguments.py                        +4    -0
dcu_megatron/core/models/gpt/fine_grained_schedule.py (view file @ dc8a93ae)

@@ -397,6 +397,10 @@ class DenseAttnNode(TransformerLayerNode):
        )
        return hidden_states

    def dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            self.layer._submodule_attention_dw()


class FakeScheduleNode:
@@ -411,6 +415,10 @@ class DenseMlpNode(TransformerLayerNode):
    def forward_impl(self, hidden_states):
        return self.layer._submodule_dense_forward(hidden_states)

    def dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            self.layer._submodule_mlp_dw()


def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
    common_state = TransformerLayerState()
@@ -418,6 +426,7 @@ def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
    attn.name = "attn"
    dispatch = FakeScheduleNode()
    mlp = DenseMlpNode(chunk_state, common_state, layer, comp_stream, event)
    mlp.name = "mlp"
    combine = FakeScheduleNode()
    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
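Per the commit message, the 1F1B fine-grained schedule now covers dense (non-MoE) MLP layers: build_non_moe_layer_plan fills the dispatch and combine slots with FakeScheduleNode placeholders, since a dense MLP has no token dispatch/combine step. A minimal sketch of what such a pass-through node could look like is shown below; this is illustrative only, and the actual FakeScheduleNode defined in fine_grained_schedule.py may differ.

class FakeScheduleNode:
    """Hypothetical no-op stand-in for schedule slots a dense layer does not use."""

    name = "noop"

    def forward_impl(self, hidden_states):
        # Nothing to dispatch or combine for a dense MLP; pass tokens through.
        return hidden_states

    def dw(self):
        # No weight-gradient work to run for a placeholder node.
        pass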
dcu_megatron/core/tensor_parallel/mappings.py (new file, 0 → 100644, view file @ dc8a93ae)

import torch

from .qcomm import q_alltoall


class _AllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm):
        """Forward function."""
        ctx.group = group
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        ctx.use_qcomm = use_qcomm

        world_size = torch.distributed.get_world_size(group=group)
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input

        input = input.contiguous()
        if output_split_sizes is None:
            # Equal split (all2all)
            if use_qcomm:
                # Quantized payload: int8 activations plus 4 extra bytes per row
                # for the bf16 scale and zero point packed by q_alltoall.
                output = input.new_empty(
                    size=[input.shape[0], input.shape[1] + 4],
                    dtype=torch.int8,
                    device=torch.cuda.current_device(),
                )
            else:
                output = torch.empty_like(input)
        else:
            # Unequal split (all2all-v)
            if use_qcomm:
                # Same packed layout as above: widen the last dimension by 4
                # so the receive buffer matches the quantized payload.
                output = input.new_empty(
                    size=[sum(output_split_sizes), input.size(1) + 4],
                    dtype=torch.int8,
                    device=torch.cuda.current_device(),
                )
            else:
                output = input.new_empty(
                    size=[sum(output_split_sizes)] + list(input.size()[1:]),
                    dtype=input.dtype,
                    device=torch.cuda.current_device(),
                )

        if use_qcomm:
            output = q_alltoall(output, input, output_split_sizes, input_split_sizes, group)
        else:
            torch.distributed.all_to_all_single(
                output,
                input,
                output_split_sizes=output_split_sizes,
                input_split_sizes=input_split_sizes,
                group=group,
            )
        return output

    @staticmethod
    def backward(ctx, *grad_output):
        """Backward function."""
        return (
            None,
            _AllToAll.apply(
                ctx.group,
                *grad_output,
                ctx.input_split_sizes,
                ctx.output_split_sizes,
                ctx.use_qcomm,
            ),
            None,
            None,
            None,
        )


def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_qcomm=False):
    """Wrapper for autograd function"""
    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_qcomm)
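For orientation, here is a rough usage sketch of the new wrapper. It assumes torch.distributed is already initialized with more than one rank, uses the default process group, and picks arbitrary tensor shapes; none of this setup is part of the commit.

import torch
import torch.distributed as dist

from dcu_megatron.core.tensor_parallel import all_to_all

# Assumed setup: dist.init_process_group(...) has run and this process owns one GPU.
group = dist.group.WORLD
hidden = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")

# Equal-split all-to-all. With use_qcomm=True each row is quantized to int8
# (plus 4 bytes carrying its bf16 scale and zero point) before the collective
# and dequantized back to bf16 on the receiving side.
out = all_to_all(group, hidden, None, None, use_qcomm=True)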
dcu_megatron/core/tensor_parallel/qcomm.py (new file, 0 → 100644, view file @ dc8a93ae)

import torch
import triton
import triton.language as tl
import random
import unittest
import json
import os
import time


@triton.jit
def _fwd_kernel_destindex_copy_quantize_kv_init_asym(
    K, Out, Out_scale_zero,
    stride_k_bs, stride_k_h, stride_k_d,
    stride_o_bs, stride_o_h, stride_o_d,
    stride_os_bs, stride_os_h, stride_os_d,
    head_num, head_dim,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    dest_index = cur_index

    m1 = offs_h[:, None] < head_num
    m2 = offs_d[None, :] < head_dim
    mask = m1 & m2

    src_data = tl.load(
        K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
        mask=mask,
        other=0.0,
    ).to(tl.float32)
    src_data_max = tl.max(src_data, axis=1, keep_dims=True)
    src_data_min = tl.min(src_data, axis=1, keep_dims=True)

    # Asymmetric 8-bit quantization: the scale maps the [min, max] range onto
    # 255 steps, the zero point shifts the minimum to 0, and the result is
    # re-centered into signed int8.
    data_scale = (src_data_max - src_data_min) / 255.0
    data_zero = (-1 * src_data_min / data_scale).to(tl.int32)
    q_src_data = (
        tl.clamp(
            (src_data / data_scale).to(tl.int32).to(tl.float32) + data_zero.to(tl.float32),
            0.0,
            255.0,
        ).to(tl.int32)
        - 128
    ).to(tl.int8)

    data_scale = data_scale.to(Out_scale_zero.dtype.element_ty)
    data_zero = data_zero.to(Out_scale_zero.dtype.element_ty)

    o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
    os_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
    oz_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1

    tl.store(o_ptrs, q_src_data, mask=mask)
    tl.store(os_ptrs, data_scale, mask=m1)
    tl.store(oz_ptrs, data_zero, mask=m1)


@torch.no_grad()
def destindex_copy_quantize_kv_init_asym(K, Out, Out_scale_zero):
    bs_seq = K.shape[0]
    head_num = K.shape[1]
    head_dim = K.shape[2]
    assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)
    grid = (bs_seq,)
    num_warps = 1

    _fwd_kernel_destindex_copy_quantize_kv_init_asym[grid](
        K, Out, Out_scale_zero,
        K.stride(0), K.stride(1), K.stride(2),
        Out.stride(0), Out.stride(1), Out.stride(2),
        Out_scale_zero.stride(0), Out_scale_zero.stride(1), Out_scale_zero.stride(2),
        head_num, head_dim,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_HEAD=BLOCK_HEAD,
        num_warps=num_warps,
        num_stages=1,
    )
    return


@triton.jit
def _bwd_kernel_destindex_dequantize_kv(
    Quantized_Out, Out_scale_zero, Dequantized_Out,
    stride_qo_bs, stride_qo_h, stride_qo_d,
    stride_os_bs, stride_os_h, stride_os_d,
    stride_do_bs, stride_do_h, stride_do_d,
    head_num, head_dim,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    scales_dtype = Out_scale_zero.dtype.element_ty
    dest_index = cur_index

    m1 = offs_h[:, None] < head_num
    m2 = offs_d[None, :] < head_dim
    mask = m1 & m2

    # Load quantized data
    q_data = tl.load(
        Quantized_Out + dest_index * stride_qo_bs + offs_h[:, None] * stride_qo_h + stride_qo_d * offs_d[None, :],
        mask=mask,
        other=0,
    )
    # Load scale and zero point
    data_scale = tl.load(
        Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None],
        mask=m1,
        other=1.0,
    )
    data_zero = tl.load(
        Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1,
        mask=m1,
        other=0,
    )
    # Dequantize
    dequantized_data = (q_data.to(tl.int32) + 128 - data_zero.to(tl.int32)).to(scales_dtype) * data_scale

    # Store dequantized data
    out_ptrs = Dequantized_Out + dest_index * stride_do_bs + stride_do_h * offs_h[:, None] + stride_do_d * offs_d[None, :]
    tl.store(out_ptrs, dequantized_data, mask=mask)


@torch.no_grad()
def destindex_dequantize_kv(Quantized_Out, Out_scale_zero, Dequantized_Out):
    bs_seq = Quantized_Out.shape[0]
    head_num = Quantized_Out.shape[1]
    head_dim = Quantized_Out.shape[2]
    assert (
        Quantized_Out.shape[1] == Dequantized_Out.shape[1]
        and Quantized_Out.shape[2] == Dequantized_Out.shape[2]
    )
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)
    grid = (bs_seq,)
    num_warps = 1

    _bwd_kernel_destindex_dequantize_kv[grid](
        Quantized_Out, Out_scale_zero, Dequantized_Out,
        Quantized_Out.stride(0), Quantized_Out.stride(1), Quantized_Out.stride(2),
        Out_scale_zero.stride(0), Out_scale_zero.stride(1), Out_scale_zero.stride(2),
        Dequantized_Out.stride(0), Dequantized_Out.stride(1), Dequantized_Out.stride(2),
        head_num, head_dim,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_HEAD=BLOCK_HEAD,
        num_warps=num_warps,
        num_stages=1,
    )
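As a plain-PyTorch reference for the arithmetic in these two kernels (illustrative only, not part of the commit): per row, the scale maps the [min, max] range onto 255 steps, the zero point shifts the minimum to 0, and values are re-centered into signed int8. Like the kernels, the sketch assumes each row has max > min.

import torch

def quantize_dequantize_reference(x: torch.Tensor):
    # Per-row asymmetric int8 quantization mirroring the kernel math above.
    x = x.float()
    x_max = x.amax(dim=-1, keepdim=True)
    x_min = x.amin(dim=-1, keepdim=True)
    scale = (x_max - x_min) / 255.0
    zero = (-x_min / scale).to(torch.int32)
    q = (
        torch.clamp((x / scale).to(torch.int32).float() + zero.float(), 0.0, 255.0).to(torch.int32)
        - 128
    ).to(torch.int8)
    # Dequantize: undo the -128 re-centering, remove the zero point, rescale.
    deq = (q.to(torch.int32) + 128 - zero).float() * scale
    return q, scale, zero, deq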
@torch.no_grad()
def fp16_to_int8s(fp16_tensor):
    fp16_bytes = fp16_tensor.contiguous().view(torch.int8)
    int8_high = fp16_bytes[::2]   # high 8 bits
    int8_low = fp16_bytes[1::2]   # low 8 bits
    return int8_high.unsqueeze(1), int8_low.unsqueeze(1)


@torch.no_grad()
def int8s_to_fp16(int8_high, int8_low):
    fp16_bytes = torch.stack([int8_high, int8_low], dim=-1).view(torch.int16)
    return fp16_bytes.view(torch.bfloat16)


def _alltoall(group, input, output_split_sizes, input_split_sizes):
    input = input.contiguous()
    if output_split_sizes is None:
        # Equal split (all2all)
        output = torch.empty_like(input)
    else:
        # Unequal split (all2all-v)
        output = input.new_empty(
            size=[sum(output_split_sizes)] + list(input.size()[1:]),
            dtype=input.dtype,
            device=torch.cuda.current_device(),
        )
    torch.distributed.all_to_all_single(
        output,
        input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    return output


def q_alltoall(output, input, output_split_sizes, input_split_sizes, group):
    t, s = input.shape[0], input.shape[1]
    input_buffer_int8 = torch.empty((t, 1, s), dtype=torch.int8, device="cuda")
    buffer_scales = torch.empty((t, 1, 2), dtype=torch.bfloat16, device="cuda")

    # Quantize each row to int8 and record its bf16 scale and zero point.
    input_q = input.unsqueeze(1)
    destindex_copy_quantize_kv_init_asym(
        input_q,
        input_buffer_int8,
        buffer_scales,
    )
    input_buffer_int8 = input_buffer_int8.squeeze()
    buffer_scales = buffer_scales.squeeze()

    # Split each bf16 scale/zero value into two int8 bytes and append them to
    # the payload, giving s + 4 int8 values per row.
    buffer_scales_h, buffer_scales_l = fp16_to_int8s(buffer_scales[:, 0])
    buffer_shift_h, buffer_shift_l = fp16_to_int8s(buffer_scales[:, 1])
    input_all = torch.cat(
        [input_buffer_int8, buffer_scales_h, buffer_scales_l, buffer_shift_h, buffer_shift_l],
        dim=1,
    )

    torch.distributed.all_to_all_single(
        output,
        input_all,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )

    # Unpack the trailing 4 bytes back into bf16 scale and zero point, then dequantize.
    scale = int8s_to_fp16(output[:, -4], output[:, -3])
    shift = int8s_to_fp16(output[:, -2], output[:, -1])
    scales = torch.cat([scale, shift], dim=1).unsqueeze(1)

    deq_out = torch.empty(
        (output.shape[0], 1, output.shape[1] - 4), dtype=torch.bfloat16, device="cuda"
    )
    destindex_dequantize_kv(output[:, :-4].unsqueeze(1), scales, deq_out)
    return deq_out.squeeze()
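The wire format produced by q_alltoall is therefore s int8 activations per row followed by 4 int8 bytes holding the row's bf16 scale and zero point (two bytes each), so a bf16 row of 2*s bytes travels as s + 4 bytes. A small illustration of how the payload is packed and sliced apart, using assumed shapes:

import torch

t, s = 8, 16  # assumed row count and hidden size, for illustration only
acts = torch.zeros(t, s, dtype=torch.int8)       # quantized activations
scale_hi = torch.zeros(t, 1, dtype=torch.int8)   # high byte of bf16 scale
scale_lo = torch.zeros(t, 1, dtype=torch.int8)   # low byte of bf16 scale
zero_hi = torch.zeros(t, 1, dtype=torch.int8)    # high byte of bf16 zero point
zero_lo = torch.zeros(t, 1, dtype=torch.int8)    # low byte of bf16 zero point

payload = torch.cat([acts, scale_hi, scale_lo, zero_hi, zero_lo], dim=1)  # [t, s + 4]

# The receiver slices exactly as q_alltoall does after the collective:
recv_acts = payload[:, :-4]                          # int8 activations
recv_scale_bytes = (payload[:, -4], payload[:, -3])  # recombined via int8s_to_fp16
recv_zero_bytes = (payload[:, -2], payload[:, -1])

For a 4096-wide bf16 activation this shrinks the per-row all-to-all traffic from 8192 bytes to 4100 bytes, roughly a 2x reduction.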
dcu_megatron/core/transformer/moe/token_dispatcher.py (view file @ dc8a93ae)

@@ -3,8 +3,8 @@ from typing import Optional, Tuple

 import torch

+from megatron.training import get_args
 from megatron.core.tensor_parallel import (
-    all_to_all,
     gather_from_sequence_parallel_region,
     reduce_scatter_to_sequence_parallel_region,
 )
@@ -15,6 +15,8 @@ from megatron.core.transformer.moe.moe_utils import (
 )
 from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
+from dcu_megatron.core.tensor_parallel import all_to_all
+

 # decouple perbatch state from MoEAlltoAllTokenDispatcher
 class MoEAlltoAllPerBatchState:
@@ -35,6 +37,13 @@ class MoEAlltoAllPerBatchState:

 class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # use_qcomm
+        args = get_args()
+        self.use_qcomm = args.use_qcomm
+
     def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
         state.num_global_tokens_per_local_expert = getattr(self, "num_global_tokens_per_local_expert", None)
@@ -125,7 +134,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             "before_ep_alltoall", tokens_per_expert
         )
         global_input_tokens = all_to_all(
-            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
+            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
         )
         return tokens_per_expert, global_input_tokens
@@ -249,7 +258,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         # Perform expert parallel AlltoAll communication
         # hidden_states: [SEQL, H] -> [SEQL, H/TP]
         permutated_local_input_tokens = all_to_all(
-            self.ep_group, hidden_states, self.input_splits, self.output_splits
+            self.ep_group, hidden_states, self.input_splits, self.output_splits, use_qcomm=self.use_qcomm
         )
         return permutated_local_input_tokens
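To summarize the wiring: the dispatcher reads args.use_qcomm once at construction and forwards it on every expert-parallel all-to-all. A hypothetical, condensed sketch of that flow (ExampleDispatcher is made up here; the real class is MoEAlltoAllTokenDispatcher above, and Megatron's global arguments are assumed to be initialized):

from megatron.training import get_args
from dcu_megatron.core.tensor_parallel import all_to_all

class ExampleDispatcher:  # hypothetical, condensed from MoEAlltoAllTokenDispatcher
    def __init__(self, ep_group):
        self.ep_group = ep_group
        # Read once at construction, as in the __init__ added by this commit.
        self.use_qcomm = get_args().use_qcomm

    def dispatch(self, tokens, output_splits, input_splits):
        # Quantized all-to-all when --use-qcomm is set, plain all-to-all otherwise.
        return all_to_all(self.ep_group, tokens, output_splits, input_splits,
                          use_qcomm=self.use_qcomm)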
dcu_megatron/core/transformer/transformer_layer.py (view file @ dc8a93ae)

@@ -10,6 +10,7 @@ from megatron.core.utils import (
     deprecate_inference_params,
     make_viewless_tensor,
 )
+from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer

 from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
@@ -32,6 +33,23 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         *,
         inference_params: Optional[Any] = None,
     ):
+        if not isinstance(self.mlp, MoELayer):
+            # Dense layer: fall back to the stock TransformerLayer forward.
+            return super().forward(
+                hidden_states=hidden_states,
+                context=context,
+                context_mask=context_mask,
+                attention_mask=attention_mask,
+                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_cos=rotary_pos_cos,
+                rotary_pos_sin=rotary_pos_sin,
+                attention_bias=attention_bias,
+                inference_context=inference_context,
+                packed_seq_params=packed_seq_params,
+                sequence_len_offset=sequence_len_offset,
+                inference_params=inference_params,
+            )
+
         (
             hidden_states,
             pre_mlp_layernorm_output,
@@ -259,10 +277,11 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return output

-    def _submodule_attention_router_compound_dw(self):
+    def _submodule_attention_dw(self):
         self.self_attention.backward_dw()
         # raise NotImplementedError("Not implemented")

+    def _submodule_attention_router_compound_dw(self):
+        self._submodule_attention_dw()
+
     def _submodule_mlp_dw(self):
         self.mlp.backward_dw()
         # raise NotImplementedError("Not implemented")
dcu_megatron/training/arguments.py (view file @ dc8a93ae)

@@ -120,6 +120,10 @@ def _add_extra_tokenizer_args(parser):
                                 'NullTokenizer',
                                 'DeepSeekV2Tokenizer'],
                        help='What type of tokenizer to use.')
+    group.add_argument('--use-qcomm',
+                       default=False,
+                       action="store_true",
+                       help='use quantized communication')
     return parser
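The new flag is a plain store_true switch. A minimal standalone illustration of its semantics, using a bare argparse parser rather than the real Megatron argument groups:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-qcomm', default=False, action="store_true",
                    help='use quantized communication')

assert parser.parse_args(['--use-qcomm']).use_qcomm is True   # quantized all-to-all enabled
assert parser.parse_args([]).use_qcomm is False               # default: regular all-to-all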