Commit bfe0b4a9 authored by dongcl

decorate _bwd_kernel_destindex_dequantize_kv with triton.jit

parent dc8a93ae
+from .mappings import all_to_all
\ No newline at end of file
@@ -5,7 +5,7 @@ from .qcomm import q_alltoall
 class _AllToAll(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, group, input, output_split_sizes, input_split_sizes):
+    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
         """Forward function."""
         ctx.group = group
         ctx.output_split_sizes = output_split_sizes
@@ -30,7 +30,7 @@ class _AllToAll(torch.autograd.Function):
             output = torch.empty_like(input)
         else:
             # Unequal split (all2all-v)
-            if use_comm:
+            if use_qcomm:
                 output = input.new_empty(
                     size=[sum(output_split_sizes)] + list(input.size()[1:]),
                     dtype=torch.int8,
...
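As context for this hunk, here is a minimal standalone sketch of the buffer-allocation logic being patched: on the unequal-split (all2all-v) path, a quantized exchange needs an int8 receive buffer sized by the total rows this rank receives. The helper name alloc_a2a_output and the concrete shapes are illustrative, not from the repo.

import torch

def alloc_a2a_output(input_, output_split_sizes, use_qcomm=False):
    # Equal split: output matches the input's shape and dtype exactly.
    if output_split_sizes is None:
        return torch.empty_like(input_)
    # Unequal split (all2all-v): first dim is the sum of per-rank receive sizes.
    dtype = torch.int8 if use_qcomm else input_.dtype
    return input_.new_empty(
        size=[sum(output_split_sizes)] + list(input_.size()[1:]),
        dtype=dtype,
    )

# This rank receives 3 + 5 = 8 rows of hidden size 16, as a quantized buffer.
buf = alloc_a2a_output(torch.randn(4, 16), [3, 5], use_qcomm=True)
assert buf.dtype == torch.int8 and buf.shape == (8, 16)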
@@ -71,6 +71,8 @@ def destindex_copy_quantize_kv_init_asym(K, Out, Out_scale_zero):
     )
     return
+
+@triton.jit
 def _bwd_kernel_destindex_dequantize_kv(
     Quantized_Out, Out_scale_zero, Dequantized_Out,
     stride_qo_bs, stride_qo_h, stride_qo_d,
...
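This hunk is the commit's subject: without the @triton.jit decorator the dequantize kernel is an ordinary Python function, and Triton's kernel[grid](...) launch syntax fails on it. Below is a minimal sketch of the pattern with an invented kernel; the [scale, zero_point] layout and all names are assumptions, not the repo's actual signature.

import torch
import triton
import triton.language as tl

@triton.jit  # required: only jitted functions support the [grid] launch form
def _dequant_sketch(Q, ScaleZero, Out, n_elem, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elem
    q = tl.load(Q + offs, mask=mask, other=0).to(tl.float32)
    scale = tl.load(ScaleZero + 0)  # assumed layout: [scale, zero_point]
    zero = tl.load(ScaleZero + 1)
    tl.store(Out + offs, (q - zero) * scale, mask=mask)

q = torch.randint(-128, 127, (1024,), dtype=torch.int8, device='cuda')
sz = torch.tensor([0.05, 8.0], device='cuda')
out = torch.empty(1024, dtype=torch.float32, device='cuda')
_dequant_sketch[(triton.cdiv(1024, 256),)](q, sz, out, 1024, BLOCK=256)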
@@ -38,7 +38,7 @@ class MoEAlltoAllPerBatchState:
 class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
     def __init__(self, *args, **kwargs):
-        super.__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
         # use_qcomm
         args = get_args()
...
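The one-character fix above deserves a note: bare super is the builtin type object itself, so super.__init__(*args, **kwargs) never reaches the parent class and raises a TypeError at construction time. A self-contained repro, independent of the Megatron classes:

class Base:
    def __init__(self, x):
        self.x = x

class Broken(Base):
    def __init__(self, x):
        super.__init__(x)  # bug: calls the builtin type's __init__, TypeError

class Fixed(Base):
    def __init__(self, x):
        super().__init__(x)  # correct: bind the proxy, then delegate to Base

assert Fixed(1).x == 1
try:
    Broken(1)
except TypeError:
    print('bare super never reaches Base.__init__')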
@@ -23,6 +23,7 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
     # add extra arguments
     parser = _add_extra_network_size_args(parser)
     parser = _add_extra_training_args(parser)
+    parser = _add_extra_initialization_args(parser)
     parser = _add_extra_distributed_args(parser)
     parser = _add_extra_tokenizer_args(parser)
     parser = _add_extra_moe_args(parser)
@@ -96,6 +97,14 @@ def _add_extra_training_args(parser):
     return parser
+
+
+def _add_extra_initialization_args(parser):
+    group = parser.add_argument_group(title='extra initialization args')
+    group.add_argument('--reproduce', action='store_true',
+                       help='Reproduce the training loss; requires --seed > 0.')
+    return parser
+
 def _add_extra_tokenizer_args(parser):
     # Remove the original parameter
     remove_original_params(parser, ["tokenizer_type"])
...
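A hypothetical standalone illustration of the new flag's contract (the parser wiring below is illustrative, not the repo's): the help text ties --reproduce to a positive --seed, so a caller should validate that pairing before relying on determinism.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--reproduce', action='store_true',
                    help='Reproduce the training loss; requires --seed > 0.')
args = parser.parse_args(['--reproduce', '--seed', '1234'])

if args.reproduce and args.seed <= 0:
    parser.error('--reproduce requires --seed > 0')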
"""Megatron initialization.""" """Megatron initialization."""
import random
import time import time
import numpy as np
import torch import torch
from datetime import timedelta from datetime import timedelta
from megatron.training import get_args from megatron.training import get_args
from megatron.core import mpu from megatron.core import mpu, tensor_parallel
def _compile_dependencies(): def _compile_dependencies():
......
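The added imports (random, numpy, and tensor_parallel) are the ingredients of Megatron-style seeding, so this file likely gains or extends a seed routine. A sketch modeled on upstream Megatron's _set_random_seed; the rank offsets are upstream's convention, assumed here:

import random
import numpy as np
import torch
from megatron.core import mpu, tensor_parallel

def _set_random_seed(seed_, data_parallel_random_init=False):
    # One base seed drives every RNG; offsets keep pipeline stages (and,
    # optionally, data-parallel replicas) from sharing dropout patterns.
    assert seed_ is not None and seed_ > 0, 'seed must be a positive integer'
    seed = seed_ + 100 * mpu.get_pipeline_model_parallel_rank()
    if data_parallel_random_init:
        seed += 10 * mpu.get_data_parallel_rank()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.device_count() > 0:
        # Seeds the per-rank CUDA RNG trackers used by tensor parallelism.
        tensor_parallel.model_parallel_cuda_manual_seed(seed)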