Commit bfe0b4a9 authored by dongcl

decorate _bwd_kernel_destindex_dequantize_kv with triton.jit

parent dc8a93ae
from .mappings import all_to_all
\ No newline at end of file
@@ -5,7 +5,7 @@ from .qcomm import q_alltoall

class _AllToAll(torch.autograd.Function):
    @staticmethod
-   def forward(ctx, group, input, output_split_sizes, input_split_sizes):
+   def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
        """Forward function."""
        ctx.group = group
        ctx.output_split_sizes = output_split_sizes
@@ -30,7 +30,7 @@ class _AllToAll(torch.autograd.Function):
            output = torch.empty_like(input)
        else:
            # Unequal split (all2all-v)
-           if use_comm:
+           if use_qcomm:
                output = input.new_empty(
                    size=[sum(output_split_sizes)] + list(input.size()[1:]),
                    dtype=torch.int8,
......
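For context on the qcomm path toggled above, here is a minimal sketch of the receive-buffer allocation the patched `forward` performs. The helper name `alloc_alltoall_output` and the non-quantized fallback branch are illustrative assumptions, not code from this commit.

```python
import torch


def alloc_alltoall_output(input_, output_split_sizes, use_qcomm=False):
    # Sketch of the receive-buffer allocation in the patched forward().
    if output_split_sizes is None:
        # Equal split: output matches the input exactly.
        return torch.empty_like(input_)
    if use_qcomm:
        # Quantized comm (use_qcomm=True): int8 buffer, one row per token
        # received across all ranks, remaining dims taken from the input.
        return input_.new_empty(
            size=[sum(output_split_sizes)] + list(input_.size()[1:]),
            dtype=torch.int8,
        )
    # Unequal split without quantization keeps the input dtype (assumed
    # fallback; the original branch is not shown in this hunk).
    return input_.new_empty(
        size=[sum(output_split_sizes)] + list(input_.size()[1:]),
        dtype=input_.dtype,
    )
```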
@@ -71,6 +71,8 @@ def destindex_copy_quantize_kv_init_asym(K, Out, Out_scale_zero):
    )
    return

+@triton.jit
def _bwd_kernel_destindex_dequantize_kv(
    Quantized_Out, Out_scale_zero, Dequantized_Out,
    stride_qo_bs, stride_qo_h, stride_qo_d,
......
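Since the commit's main change is decorating `_bwd_kernel_destindex_dequantize_kv` with `triton.jit`, below is a minimal, self-contained sketch of how a `@triton.jit` dequantize kernel is defined and launched with a grid. This is not the repo's kernel; the per-row (scale, zero-point) layout, the `(q - zero) * scale` convention, and all names are illustrative assumptions.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _dequant_rows_kernel(Q, ScaleZero, Out,
                         stride_q, stride_sz, stride_o,
                         D: tl.constexpr, BLOCK: tl.constexpr):
    # One program per row: load int8 values plus this row's (scale, zero)
    # pair and write the dequantized float32 row.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < D
    q = tl.load(Q + row * stride_q + offs, mask=mask, other=0).to(tl.float32)
    scale = tl.load(ScaleZero + row * stride_sz + 0)
    zero = tl.load(ScaleZero + row * stride_sz + 1)
    tl.store(Out + row * stride_o + offs, (q - zero) * scale, mask=mask)


def dequant_rows(q_int8, scale_zero):
    # q_int8: [rows, D] int8; scale_zero: [rows, 2] float32.
    rows, d = q_int8.shape
    out = torch.empty((rows, d), device=q_int8.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(d)
    _dequant_rows_kernel[(rows,)](
        q_int8, scale_zero, out,
        q_int8.stride(0), scale_zero.stride(0), out.stride(0),
        D=d, BLOCK=BLOCK,
    )
    return out
```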
@@ -38,7 +38,7 @@ class MoEAlltoAllPerBatchState:

class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
    def __init__(self, *args, **kwargs):
-       super.__init__(*args, **kwargs)
+       super().__init__(*args, **kwargs)
        # use_qcomm
        args = get_args()
......
@@ -23,6 +23,7 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
    # add extra arguments
    parser = _add_extra_network_size_args(parser)
    parser = _add_extra_training_args(parser)
+   parser = _add_extra_initialization_args(parser)
    parser = _add_extra_distributed_args(parser)
    parser = _add_extra_tokenizer_args(parser)
    parser = _add_extra_moe_args(parser)
@@ -96,6 +97,14 @@ def _add_extra_training_args(parser):
    return parser

+def _add_extra_initialization_args(parser):
+   group = parser.add_argument_group(title='extra initialization args')
+   group.add_argument('--reproduce', action='store_true',
+                      help='Reproduce the training loss; requires setting --seed > 0.')
+   return parser

def _add_extra_tokenizer_args(parser):
    # remove the original argument
    remove_original_params(parser, ["tokenizer_type"])
......
"""Megatron initialization."""
import random
import time
import numpy as np
import torch
from datetime import timedelta
from megatron.training import get_args
from megatron.core import mpu
from megatron.core import mpu, tensor_parallel
def _compile_dependencies():
......
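The new `--reproduce` flag and the `tensor_parallel` import suggest a deterministic-initialization path. Below is a hedged sketch of how such a flag could be consumed when seeding; the helper name and the exact set of deterministic switches are assumptions, and `model_parallel_cuda_manual_seed` is assumed to run after the model-parallel groups are initialized.

```python
import random

import numpy as np
import torch
from megatron.core import tensor_parallel


def _set_reproducible_seed(seed: int, reproduce: bool = False):
    # Hypothetical helper: seed every RNG that affects training.
    assert seed > 0, "--reproduce requires --seed > 0"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.device_count() > 0:
        # Assumes model-parallel groups are already initialized; this seeds
        # the per-rank CUDA RNG tracker used for dropout.
        tensor_parallel.model_parallel_cuda_manual_seed(seed)
    if reproduce:
        # Trade throughput for bitwise-stable kernels.
        torch.use_deterministic_algorithms(True, warn_only=True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
```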