evt_fugx1 / dcu_megatron · Commits

Commit bfe0b4a9, authored May 19, 2025 by dongcl

decorate _bwd_kernel_destindex_dequantize_kv with triton.jit

parent dc8a93ae
Showing 6 changed files with 19 additions and 4 deletions (+19 −4)
dcu_megatron/core/tensor_parallel/__init__.py             +1 −0
dcu_megatron/core/tensor_parallel/mappings.py              +2 −2
dcu_megatron/core/tensor_parallel/qcomm.py                 +2 −0
dcu_megatron/core/transformer/moe/token_dispatcher.py      +1 −1
dcu_megatron/training/arguments.py                         +9 −0
dcu_megatron/training/initialize.py                        +4 −1
dcu_megatron/core/tensor_parallel/__init__.py  (new file, 0 → 100644)  +1 −0

+from .mappings import all_to_all
\ No newline at end of file
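
The new package __init__ simply re-exports the all_to_all helper, so callers can import it from the package rather than from the mappings module. A usage sketch (assuming all_to_all is defined in mappings.py, as the import implies):

from dcu_megatron.core.tensor_parallel import all_to_all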
dcu_megatron/core/tensor_parallel/mappings.py  +2 −2

@@ -5,7 +5,7 @@ from .qcomm import q_alltoall
 class _AllToAll(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, group, input, output_split_sizes, input_split_sizes):
+    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
         """Forward function."""
         ctx.group = group
         ctx.output_split_sizes = output_split_sizes
@@ -30,7 +30,7 @@ class _AllToAll(torch.autograd.Function):
             output = torch.empty_like(input)
         else:
             # Unequal split (all2all-v)
-            if use_comm:
+            if use_qcomm:
                 output = input.new_empty(
                     size=[sum(output_split_sizes)] + list(input.size()[1:]),
                     dtype=torch.int8,
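
The second hunk fixes the previously undefined use_comm reference now that the flag is an explicit forward argument. With use_qcomm set, the unequal-split (all2all-v) path allocates an int8 receive buffer sized by the output splits. A standalone illustration of that allocation with toy shapes (not from the repository):

import torch

# Toy example of the buffer allocation in the use_qcomm branch above.
input_ = torch.randn(6, 4, 8)        # e.g. 6 local tokens with trailing dims [4, 8]
output_split_sizes = [2, 3, 5]       # tokens expected from each peer rank

# First dim is the total number of received tokens; trailing dims match the input.
output = input_.new_empty(
    size=[sum(output_split_sizes)] + list(input_.size()[1:]),
    dtype=torch.int8,                # quantized communication buffer
)
print(output.shape, output.dtype)    # torch.Size([10, 4, 8]) torch.int8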
dcu_megatron/core/tensor_parallel/qcomm.py  +2 −0

@@ -71,6 +71,8 @@ def destindex_copy_quantize_kv_init_asym(K, Out, Out_scale_zero):
     )
     return
 
+
+@triton.jit
 def _bwd_kernel_destindex_dequantize_kv(
     Quantized_Out, Out_scale_zero, Dequantized_Out,
     stride_qo_bs, stride_qo_h, stride_qo_d,
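
This is the change named in the commit title: without @triton.jit the dequantize kernel is a plain Python function, so it cannot be launched with Triton's kernel[grid](...) syntax and its tl.* operations have no meaning. A minimal, self-contained example of a decorated kernel and its launch (unrelated to the repository's kernel; requires a Triton-capable GPU):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance copies one BLOCK-sized slice, masked at the tail.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

src = torch.arange(1024, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
grid = (triton.cdiv(src.numel(), 256),)
copy_kernel[grid](src, dst, src.numel(), BLOCK=256)   # launch only works because of @triton.jit
assert torch.equal(src, dst)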
dcu_megatron/core/transformer/moe/token_dispatcher.py  +1 −1

@@ -38,7 +38,7 @@ class MoEAlltoAllPerBatchState:
 class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
     def __init__(self, *args, **kwargs):
-        super.__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
 
         # use_qcomm
         args = get_args()
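
The one-character fix matters: bare super is the built-in type, not a proxy bound to the current instance, so super.__init__(...) raises a TypeError instead of calling the parent constructor. A standalone illustration (class names are hypothetical, not from the repository):

class Base:
    def __init__(self, *args, **kwargs):
        self.args = args

class Broken(Base):
    def __init__(self, *args, **kwargs):
        # `super` here is the built-in type itself; this raises
        # TypeError: descriptor '__init__' requires a 'super' object ...
        super.__init__(*args, **kwargs)

class Fixed(Base):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)   # bound proxy; calls Base.__init__

Fixed(1, 2)                                 # works
try:
    Broken(1, 2)
except TypeError as exc:
    print(exc)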
dcu_megatron/training/arguments.py  +9 −0

@@ -23,6 +23,7 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
     # add extra arguments
     parser = _add_extra_network_size_args(parser)
     parser = _add_extra_training_args(parser)
+    parser = _add_extra_initialization_args(parser)
     parser = _add_extra_distributed_args(parser)
     parser = _add_extra_tokenizer_args(parser)
     parser = _add_extra_moe_args(parser)
@@ -96,6 +97,14 @@ def _add_extra_training_args(parser):
     return parser
 
 
+def _add_extra_initialization_args(parser):
+    group = parser.add_argument_group(title='extra initialization args')
+
+    group.add_argument('--reproduce', action='store_true',
+                       help='reproduce train loss, need set --seed > 0.')
+
+    return parser
+
 def _add_extra_tokenizer_args(parser):
     # remove the original parameter
     remove_original_params(parser, ["tokenizer_type"])
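
The new group adds a single store_true switch. A standalone sketch of how the flag parses, mirroring the added argument group outside the repository's parser patching:

import argparse

# Mirrors the added argument group; only the --reproduce flag is shown.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='extra initialization args')
group.add_argument('--reproduce', action='store_true',
                   help='reproduce train loss, need set --seed > 0.')

print(parser.parse_args([]).reproduce)               # False when omitted
print(parser.parse_args(['--reproduce']).reproduce)  # True when passed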
dcu_megatron/training/initialize.py  +4 −1

 """Megatron initialization."""
+import random
 import time
 
+import numpy as np
 import torch
 
 from datetime import timedelta
 
 from megatron.training import get_args
-from megatron.core import mpu
+from megatron.core import mpu, tensor_parallel
 
 
 def _compile_dependencies():
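
The added random and numpy imports (together with tensor_parallel from megatron.core) are the usual ingredients for seeding RNG state during initialization, which lines up with the new --reproduce flag; the hunk does not show the call sites, so the following is only a sketch of the common pattern, not the repository's code:

import random

import numpy as np
import torch

# Sketch of the usual "seed everything" pattern these imports support;
# how initialize.py actually uses them is not shown in this commit.
def _seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)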