Commit 12415ddb authored by dongchy920

Initial commit
from .llama_model import LLaMAModel
import torch
from megatron_mini import get_args
from megatron_mini.model.module import MegatronModule
from megatron_mini.model.transformer import LLaMATransformer
class LLaMAModel(MegatronModule):
"""Code Generation Model for Multilingual Program Synthesis."""
def __init__(self, parallel_output=False):
super(LLaMAModel, self).__init__()
args = get_args()
self.parallel_output = parallel_output
self._language_model_key = "llama_model"
self.language_model = LLaMATransformer(
init_method=lambda x:x,
output_layer_init_method=lambda x:x
)
def forward(self, tokens: torch.Tensor, start_pos: int, return_hidden=False):
# Language model.
lm_logits = self.language_model(tokens, start_pos, return_hidden)
return lm_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module"""
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron_mini import get_args
from megatron_mini.core import mpu, tensor_parallel
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def word_embeddings_weight(self):
if self.pre_process:
return self.language_model.embedding.word_embeddings.weight
else:
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false')
return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, perform an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages (a minimal sketch follows this method).
if mpu.is_pipeline_last_stage() and \
not self.pre_process:
assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy the first
            # stage's weights using the all_reduce below.
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std),
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
self.pre_process:
self.language_model.embedding.zero_parameters()
if not torch.distributed.is_initialized():
if not getattr(MegatronModule, "embedding_warning_printed", False):
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
MegatronModule.embedding_warning_printed = True
return
# Ensure that first and last stages have the same initial parameter
# values.
if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
        # Ensure that the encoder (first stage) and decoder (split stage) position
        # embeddings have the same initial parameter values.
# NOTE: We don't currently support T5 with the interleaved schedule.
if mpu.is_rank_in_position_embedding_group() and \
args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_position_embedding_group())
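# A minimal, hypothetical sketch of step (3) above as it might appear in a training
# loop. This repository only ships inference code, so the loop itself is not part of
# this file; `model` is a placeholder, and whether `.grad` or a fused `.main_grad`
# buffer is reduced depends on the optimizer configuration.
#
#   if mpu.is_rank_in_embedding_group():
#       weight = model.word_embeddings_weight()
#       torch.distributed.all_reduce(weight.grad,
#                                    group=mpu.get_embedding_group())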
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
class PositionEmbeddingType(enum.Enum):
rotary = 1
absolute = 2
alibi = 3
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Transformer."""
import math
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Union
from megatron_mini import get_args, core
from .module import MegatronModule
from megatron_mini.core import mpu, tensor_parallel
from megatron_mini.model.module import AttnMaskType
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True, file=sys.stderr)
else:
print(message, flush=True, file=sys.stderr)
try:
from einops import rearrange
except ImportError:
rearrange = None
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
print_rank_0(f"WARNING: FlashAttention is not available")
flash_attn_varlen_func = None
try:
from flash_attn.flash_attn_interface import flash_attn_with_kvcache
except ImportError:
print_rank_0(f"WARNING: FlashAttention is not available")
flash_attn_with_kvcache = None
class FusedScaleMaskSoftmax(nn.Module):
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def _args_to_kwargs():
args = get_args()
common_kwargs = {
"params_dtype": args.params_dtype,
"use_cpu_initialization": args.use_cpu_initialization,
"perform_initialization": False,
"gradient_accumulation_fusion": False,
"sequence_parallel_enabled": args.sequence_parallel,
}
return common_kwargs
class FlashSelfAttention(torch.nn.Module):
def __init__(self, causal=False, softmax_scale=None):
super().__init__()
assert flash_attn_varlen_func is not None, ('Please install FlashAttention first, '
'e.g., with pip install flash-attn')
assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
self.causal = causal
self.softmax_scale = softmax_scale
def forward(self, q, k, v):
"""Implements the softmax like attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
assert all((i.is_cuda for i in (q,k,v)))
if flash_attn_with_kvcache is not None:
"""
q: (batch_size, seqlen_q, nheads, headdim)
k_cache: (batch_size, seqlen_kv, nheads_k, headdim)
v_cache: (batch_size, seqlen_kv, nheads_k, headdim)
we do not pass k and v to flash_attn_with_kvcache, because our k and v are packed with kv cache
"""
context = flash_attn_with_kvcache(
q,
k_cache=k,
v_cache=v,
k=None,
v=None,
cache_seqlens=None,
softmax_scale=self.softmax_scale,
causal=True,
)
return context
else:
batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = k.shape[1]
q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
is_causal = seqlen_q == seqlen_k
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k.device)
output = flash_attn_varlen_func(
q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
0.0,
softmax_scale=self.softmax_scale, causal=is_causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
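# Shape note (illustrative, based on the two code paths above): with
# flash_attn_with_kvcache available, q is (B, S_q, H, D) and k/v are the packed
# KV caches of shape (B, S_kv, H, D); the returned context is (B, S_q, H, D).
# The varlen fallback flattens q/k/v to ((B*S), H, D), calls flash_attn_varlen_func
# with per-sequence offsets, and reshapes the result back to (B, S_q, H, D).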
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = False
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.sequence_parallel = args.sequence_parallel
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = core.utils.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
def forward(self, query_layer, key_layer,
value_layer, attention_mask):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
        # preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
(output_size[0]*output_size[1], output_size[2], output_size[3]),
query_layer.dtype, "mpu")
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs
# ===========================
if os.getenv('AIXCODER_DEBUG') == 'ON':
mp_0 = query_layer.shape[1]
stats = torch.stack([
torch.mean(attention_scores[:,:mp_0]).to(torch.float32),
torch.std(attention_scores[:,:mp_0]).to(torch.float32),
torch.max(attention_scores[:,:mp_0]).to(torch.float32)
]).detach().cpu().numpy()
print_rank_0(
f"\nAttention - scores before softmax".ljust(40) + f": {stats}, {attention_scores[:,:mp_0].dtype}, {attention_scores[:,:mp_0].shape}"
)
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
if os.getenv('AIXCODER_DEBUG') == 'ON':
mp_0 = query_layer.shape[1]
stats = torch.stack([
torch.mean(attention_probs[:,:mp_0]).to(torch.float32),
torch.std(attention_probs[:,:mp_0]).to(torch.float32),
torch.max(attention_probs[:,:mp_0]).to(torch.float32)
]).detach().cpu().numpy()
print_rank_0(
f"\nAttention_probs".ljust(40) + f": {stats}, {attention_probs[:,:mp_0].dtype}, {attention_probs[:,:mp_0].shape}"
)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
"""
Implementation for LLaMA
"""
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
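# Quick numerical sanity check for RMSNorm (illustrative only, not executed by the
# model code; weight is initialized to ones):
#
#   norm = RMSNorm(dim=4)
#   x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#   expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
#   assert torch.allclose(norm(x), expected)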
# @torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
    # Handle a possible sequence-length mismatch between q and k
cos = cos[:x.shape[0], :, :, :]
sin = sin[:x.shape[0], :, :, :]
part_1 = x * cos
x1, x2 = x.chunk(2, dim=-1)
part_2 = torch.cat((-x2, x1), dim=-1) * sin
return part_1 + part_2
class RotaryEmbedding(MegatronModule):
def __init__(self, seq_dimension=0, rope_theta=10000, *_, **__):
super().__init__()
args = get_args()
self.args = args
self.seq_dimension = seq_dimension
dim_model = args.hidden_size // args.num_attention_heads
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim_model, 2).float() / dim_model))
# persistent: whether the buffer is part of this module's :attr:`state_dict`
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = self.args.seq_length
t = torch.arange(
self._seq_len_cached, dtype=torch.float32
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos()[:, None, None, :]
self._sin_cached = emb.sin()[:, None, None, :]
def _update_cos_sin_tables(self, x):
seq_len = x.shape[self.seq_dimension]
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seq_len != self._seq_len_cached
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seq_len
t = torch.arange(
x.shape[self.seq_dimension], device=x.device, dtype=torch.float32
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.seq_dimension == 1:
self._cos_cached = emb.cos()[None, :, None, :].to(x.dtype)
self._sin_cached = emb.sin()[None, :, None, :].to(x.dtype)
elif self.seq_dimension == 0:
self._cos_cached = emb.cos()[:, None, None, :].to(x.dtype)
self._sin_cached = emb.sin()[:, None, None, :].to(x.dtype)
else:
raise NotImplementedError
return self._cos_cached, self._sin_cached
def forward(
self, query_layer: torch.Tensor, key_layer: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
query_layer: [seq_len, bsz, local_num_heads, heads_dim]
"""
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
key_layer
)
return (
apply_rotary_pos_emb(query_layer, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(key_layer, self._cos_cached, self._sin_cached),
)
def set_devices_dtype(self, x):
self._cos_cached = self._cos_cached.to(x.dtype).to(x.device)
self._sin_cached = self._sin_cached.to(x.dtype).to(x.device)
def get_freqs_cis(self, h_shape, start_pos=0):
seq_len = h_shape[self.seq_dimension]
if self.seq_dimension == 1:
return torch.stack((self._cos_cached[:, start_pos: start_pos + seq_len], self._sin_cached[:, start_pos: start_pos + seq_len]), dim=0)
elif self.seq_dimension == 0:
return torch.stack((self._cos_cached[start_pos: start_pos + seq_len], self._sin_cached[start_pos: start_pos + seq_len]), dim=0)
else:
raise NotImplementedError
@staticmethod
# @torch.jit.script
def apply_rotary(
query_layer: torch.Tensor, key_layer: torch.Tensor, freqs_cis: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = freqs_cis[0]
sin = freqs_cis[1]
# handle short query
q_part_1 = query_layer * cos[-query_layer.shape[0]:]
q_x1, q_x2 = query_layer.chunk(2, dim=-1)
q_part_2 = torch.cat((-q_x2, q_x1), dim=-1) * sin[-query_layer.shape[0]:]
k_part_1 = key_layer * cos
k_x1, k_x2 = key_layer.chunk(2, dim=-1)
k_part_2 = torch.cat((-k_x2, k_x1), dim=-1) * sin
return q_part_1 + q_part_2, k_part_1 + k_part_2
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, seq_dimension=0, rope_theta=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(seq_dimension, rope_theta, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos()[:, None, None, :]
self._sin_cached = emb.sin()[:, None, None, :]
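# Illustrative sketch of RotaryEmbedding.apply_rotary on synthetic tensors (the
# shapes below are assumptions for the example, not read from get_args()):
#
#   seq, bsz, heads, hdim = 8, 1, 2, 16
#   q = torch.randn(seq, bsz, heads, hdim)
#   k = torch.randn(seq, bsz, heads, hdim)
#   t = torch.arange(seq, dtype=torch.float32)
#   inv_freq = 1.0 / (10000 ** (torch.arange(0, hdim, 2).float() / hdim))
#   freqs = torch.outer(t, inv_freq)
#   emb = torch.cat((freqs, freqs), dim=-1)
#   freqs_cis = torch.stack((emb.cos()[:, None, None, :],
#                            emb.sin()[:, None, None, :]), dim=0)
#   q_rot, k_rot = RotaryEmbedding.apply_rotary(q, k, freqs_cis)  # same shapes as q, k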
class LLaMAttention(MegatronModule):
def __init__(self, init_method,
output_layer_init_method, layer_number):
super().__init__()
args = get_args()
self.params_dtype = args.params_dtype
mp_world_size = mpu.get_tensor_model_parallel_world_size()
self.n_local_heads = core.utils.divide(
args.num_attention_heads, mp_world_size)
self.head_dim = args.hidden_size // args.num_attention_heads
self.local_num_kv_heads = core.utils.divide(
args.num_kv_heads, mp_world_size)
self.attention_head_type = args.attention_head_type
self.sequence_parallel = args.sequence_parallel
if self.attention_head_type == "multihead":
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * args.hidden_size,
gather_output=False,
init_method=init_method,
bias=False,
skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.core_attention = CoreAttention(layer_number,
AttnMaskType.causal)
self.cache_k = torch.zeros(
(args.max_position_embeddings, args.micro_batch_size, self.n_local_heads, self.head_dim)
).cuda()
self.cache_v = torch.zeros(
(args.max_position_embeddings, args.micro_batch_size, self.n_local_heads, self.head_dim)
).cuda()
elif self.attention_head_type == "groupedquery":
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.hidden_size + args.num_kv_heads * 2 * self.head_dim,
gather_output=False,
init_method=init_method,
bias=False,
skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
cached_kv_heads_num = core.utils.divide(
args.num_kv_heads, mp_world_size)
self.core_attention = CoreAttention(layer_number,
AttnMaskType.causal)
# shape: [seq_len, bsz, local_kv_head_num=group_size, key_or_value=1, head_dim]
self.cache_k = torch.zeros(
(args.max_position_embeddings, args.micro_batch_size, cached_kv_heads_num, 1, self.head_dim)
).cuda()
self.cache_v = torch.zeros(
(args.max_position_embeddings, args.micro_batch_size, cached_kv_heads_num, 1, self.head_dim)
).cuda()
else:
raise ValueError(f"attention type was Wrong with {self.attention_head_type} in llama")
self.wo = tensor_parallel.RowParallelLinear(
args.hidden_size,
args.hidden_size,
params_dtype=self.params_dtype,
input_is_parallel=True if mp_world_size > 1 else False,
init_method=output_layer_init_method,
perform_initialization=False,
use_cpu_initialization=True,
bias=False,
skip_bias_add=True)
self.use_flash_attn = args.use_flash_attn
self.core_attention_flash = None
if self.use_flash_attn and flash_attn_varlen_func is not None and rearrange is not None:
self.core_attention_flash = FlashSelfAttention(
causal=True
)
else:
self.use_flash_attn = False
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, attention_mask: Optional[torch.Tensor]):
seqlen, bsz, _ = x.shape
if self.attention_head_type == 'multihead':
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(x)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.n_local_heads,
3 * self.head_dim)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
elif self.attention_head_type == 'groupedquery':
            # Attention heads [sq, b, h] --> [sq, b, ((np + 2 * kvnp) * hn)]
mixed_x_layer, _ = self.query_key_value(x)
# [sq, b, ((np + kvnp*2) * hn)] --> [sq, b, local_num_kv_heads, np//kvnp + 2, hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(-1, self.n_local_heads // self.local_num_kv_heads + 2, self.head_dim)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, local_num_kv_heads, np//kvnp + 2, hn] -->
# [sq, b, local_num_kv_heads, np//kvnp, hn], 2 * [sq, b, local_num_kv_heads, 1, hn]
query_layer = mixed_x_layer[:,:,:,:-2]
key_layer = mixed_x_layer[:,:,:,[-2]]
value_layer = mixed_x_layer[:,:,:,[-1]]
self.cache_k = self.cache_k.to(query_layer)
self.cache_v = self.cache_v.to(query_layer)
self.cache_k[start_pos : start_pos + seqlen, :bsz] = key_layer
self.cache_v[start_pos : start_pos + seqlen, :bsz] = value_layer
key_layer = self.cache_k[: start_pos + seqlen, :bsz]
value_layer = self.cache_v[: start_pos + seqlen, :bsz]
if self.attention_head_type == 'groupedquery':
            # TODO: FlashAttention currently only supports multi-head attention, so broadcast the KV heads to match the number of query heads
sq, b, lkv, np_lkv, hn = query_layer.size()
kv_size = [key_layer.size()[0], b, lkv, np_lkv, hn]
key_layer = torch.broadcast_to(key_layer, kv_size)
value_layer = torch.broadcast_to(value_layer, kv_size)
query_layer, key_layer, value_layer = [x.flatten(2, 3) for x in (query_layer, key_layer, value_layer)]
query_layer, key_layer = RotaryEmbedding.apply_rotary(query_layer, key_layer, freqs_cis=freqs_cis)
if not self.use_flash_attn:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask)
else:
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
for x in (query_layer, key_layer, value_layer)]
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
context_layer = self.core_attention_flash(q, k, v)
else:
context_layer = self.core_attention_flash(q, k, v)
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
output, _ = self.wo(context_layer)
return output
class LLaMAFeedForward(MegatronModule):
def __init__(
self,init_method, output_layer_init_method
):
super().__init__()
args = get_args()
self.params_dtype = args.params_dtype
if args.inner_hidden_dim is not None and isinstance(args.inner_hidden_dim, int) and args.inner_hidden_dim > 0:
inn_hidden_dim = args.inner_hidden_dim
else:
ffn_expand_rate = 4
inn_hidden_dim = int(2 * (args.hidden_size * ffn_expand_rate) / 3)
# make SwiGLU hidden layer size multiple of large power of 2
multiple_of = 256
inn_hidden_dim = multiple_of * ((inn_hidden_dim + multiple_of - 1) // multiple_of)
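        # Worked example of the rounding above (assuming hidden_size = 4096, the
        # LLaMA-7B width; the actual value comes from args):
        #   2 * (4096 * 4) / 3 = 10922.67 -> int() = 10922
        #   256 * ((10922 + 255) // 256) = 11008  -> SwiGLU hidden dimension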
self.w1 = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
inn_hidden_dim,
params_dtype=self.params_dtype,
gather_output=False,
init_method=init_method,
perform_initialization=False,
use_cpu_initialization=True,
skip_bias_add=True,
bias=False,
)
self.w2 = tensor_parallel.RowParallelLinear(
inn_hidden_dim,
args.hidden_size,
params_dtype=self.params_dtype,
input_is_parallel=True if args.tensor_model_parallel_size > 1 else False,
init_method=output_layer_init_method,
perform_initialization=False,
use_cpu_initialization=True,
skip_bias_add=True,
bias=False,
)
self.w3 = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
inn_hidden_dim,
params_dtype=self.params_dtype,
gather_output=False,
init_method=init_method,
perform_initialization=False,
use_cpu_initialization=True,
skip_bias_add=True,
bias=False,
)
def forward(self, x):
part_1, _ = self.w1(x)
part_1 = F.silu(part_1)
part_2, _ = self.w3(x)
final, _ = self.w2(part_1 * part_2)
return final
class LLaMATransformerBlock(MegatronModule):
def __init__(self, init_method,
output_layer_init_method, layer_number):
super().__init__()
args = get_args()
self.attention = LLaMAttention(
init_method=init_method, output_layer_init_method=output_layer_init_method, layer_number=layer_number)
self.feed_forward = LLaMAFeedForward(
init_method, output_layer_init_method
)
self.layer_number = layer_number
# epsilon: 1e-5
self.attention_norm = RMSNorm(args.hidden_size, eps=args.layernorm_epsilon)
self.ffn_norm = RMSNorm(args.hidden_size, eps=args.layernorm_epsilon)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
class LLaMATransformer(MegatronModule):
def __init__(self, init_method, output_layer_init_method):
super().__init__()
args = get_args()
self.num_layers = args.num_layers
self.seq_dimension = 0
self.params_dtype = args.params_dtype
self.tok_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method,
params_dtype=self.params_dtype,
use_cpu_initialization=True,
perform_initialization=False
)
self.layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
self.layers.append(LLaMATransformerBlock(init_method, output_layer_init_method, layer_id))
self.norm = RMSNorm(args.hidden_size, eps=args.layernorm_epsilon)
# mapping hidden_states to logits value
self.output = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.padded_vocab_size,
params_dtype=self.params_dtype,
gather_output=True,
init_method=init_method,
perform_initialization=False,
use_cpu_initialization=True,
skip_bias_add=True,
bias=False
)
if args.rope_linear_scaling_factor > 1:
self.rope = LinearScalingRotaryEmbedding(seq_dimension=self.seq_dimension, rope_theta=args.rope_theta, scaling_factor=args.rope_linear_scaling_factor)
else:
self.rope = RotaryEmbedding(seq_dimension=self.seq_dimension, rope_theta=args.rope_theta)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int, return_hidden=False) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = h.transpose(0, 1).contiguous()
self.rope.set_devices_dtype(h)
h_shape = h.shape
        h_shape = [s + start_pos if s_i == 0 else s for s_i, s in enumerate(h_shape)]
freqs_cis = self.rope.get_freqs_cis(h_shape)
attention_mask = None
if seqlen > 1:
attention_mask = torch.tril(torch.ones(
(1, seqlen, seqlen), device=tokens.device)).view(
1, 1, seqlen, seqlen).type_as(h)
attention_mask = (attention_mask < 0.5)
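            # For example, with seqlen = 3 the resulting boolean mask
            # (True = masked out by attention_mask_func) is:
            #   [[False, True,  True ],
            #    [False, False, True ],
            #    [False, False, False]]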
hidden_list = []
for layer_id, layer in enumerate(self.layers):
# shape: [seq_len, bsz, hidden_size]
h = layer(h, start_pos, freqs_cis, attention_mask)
if return_hidden and layer_id in {0, int(self.num_layers/3), int(self.num_layers/5 * 4)}:
hidden_list.append(h.float().transpose(0,1).contiguous())
h = self.norm(h)
h = h.transpose(0,1).contiguous()
if return_hidden:
hidden_list.append(h.float())
output, _ = self.output(h)
if return_hidden:
return output.float(), hidden_list, self.output.weight.float()
else:
return output.float()
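# Hypothetical greedy-decoding sketch using the KV cache via start_pos. `model` is
# an LLaMAModel instance and `prompt`/`max_new_tokens` are placeholders; the real
# generation loop lives in the serving code, not in this file.
#
#   logits = model(prompt, start_pos=0)                    # prefill: caches K/V for the prompt
#   next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
#   for step in range(max_new_tokens):
#       logits = model(next_id, start_pos=prompt.shape[1] + step)  # one token per step
#       next_id = logits[:, -1].argmax(dim=-1, keepdim=True)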
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""General utilities."""
import sys
import re
import time
import datetime
import torch
from megatron_mini.model.module import ModelType
from megatron_mini.filter import SensitiveInforRM
from megatron_mini import (
get_args,
)
from collections import OrderedDict
from typing import List, Optional
from typing import Tuple
import numpy as np
import os
from megatron_mini.core import mpu, parallel_state
from megatron_mini.core.tensor_parallel import set_defaults_if_not_set_tensor_model_parallel_attributes
def get_model_for_infer(model_provider_func):
"""Build the model."""
args = get_args()
args.model_type = ModelType.encoder_or_decoder
# Build model.
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and args.virtual_pipeline_model_parallel_size is not None
):
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process, post_process=post_process
)
model.append(this_model)
else:
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
model = model_provider_func(pre_process=pre_process, post_process=post_process)
if not isinstance(model, list):
model = [model]
for model_module in model:
for param in model_module.parameters():
set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if parallel_state.get_data_parallel_rank() == 0:
print(
" > number of parameters on (tensor, pipeline) "
"model parallel rank ({}, {}): {}".format(
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_pipeline_model_parallel_rank(),
sum(
[
sum(
[
p.ds_numel if hasattr(p, "ds_id") else p.nelement()
for p in model_module.parameters()
]
)
for model_module in model
]
),
),
flush=True, file=sys.stderr
)
return model
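# Hypothetical usage sketch; the actual model_provider_func is defined by the
# inference entry point, not in this file:
#
#   def model_provider_func(pre_process=True, post_process=True):
#       return LLaMAModel(parallel_output=False)
#
#   model = get_model_for_infer(model_provider_func)  # always returns a list of modules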
LANGUAGE_WRAPPER = {
"c" : "// <AIX-SPE>",
"c++" : "// <AIX-SPE>",
"cpp" : "// <AIX-SPE>",
"c#" : "// <AIX-SPE>",
"csharp" : "// <AIX-SPE>",
"c-sharp" : "// <AIX-SPE>",
"css" : "/* <AIX-SPE> */",
"cuda" : "// <AIX-SPE>",
"dart" : "// <AIX-SPE>",
"lua" : "// <AIX-SPE>",
"objectivec" : "// <AIX-SPE>",
"objective-c" : "// <AIX-SPE>",
"objective-c++": "// <AIX-SPE>",
"python" : "# <AIX-SPE>",
"perl" : "# <AIX-SPE>",
"prolog" : "% <AIX-SPE>",
"swift" : "// <AIX-SPE>",
"lisp" : "; <AIX-SPE>",
"java" : "// <AIX-SPE>",
"scala" : "// <AIX-SPE>",
"tex" : "% <AIX-SPE>",
"vue" : "<!--<AIX-SPE>-->",
"markdown" : "<!--<AIX-SPE>-->",
"html" : "<!--<AIX-SPE>-->",
"php" : "// <AIX-SPE>",
"js" : "// <AIX-SPE>",
"javascript" : "// <AIX-SPE>",
"typescript" : "// <AIX-SPE>",
"go" : "// <AIX-SPE>",
"shell" : "# <AIX-SPE>",
"rust" : "// <AIX-SPE>",
"sql" : "-- <AIX-SPE>",
"kotlin" : "// <AIX-SPE>",
"vb" : "' <AIX-SPE>",
"ruby" : "# <AIX-SPE>",
"pascal" : "// <AIX-SPE>",
"r" : "# <AIX-SPE>",
"fortran" : "!<AIX-SPE>",
"lean" : "-- <AIX-SPE>",
"matlab" : "% <AIX-SPE>",
"delphi" : "{<AIX-SPE>}",
"scheme" : "; <AIX-SPE>",
"basic" : "' <AIX-SPE>",
"assembly" : "; <AIX-SPE>",
"groovy" : "// <AIX-SPE>",
"abap" : "* <AIX-SPE>",
"gdscript" : "# <AIX-SPE>",
"haskell" : "-- <AIX-SPE>",
"julia" : "# <AIX-SPE>",
"elixir" : "# <AIX-SPE>",
"excel" : "' <AIX-SPE>",
"clojure" : "; <AIX-SPE>",
"actionscript" : "// <AIX-SPE>",
"solidity" : "// <AIX-SPE>",
"powershell" : "# <AIX-SPE>",
"erlang" : "% <AIX-SPE>",
"cobol" : "// <AIX-SPE>",
"alloy" : "/* <AIX-SPE> */",
"awk" : "// <AIX-SPE>",
"thrift" : "/* <AIX-SPE> */",
"sparql" : "# <AIX-SPE>",
"augeas" : "// <AIX-SPE>",
"cmake" : "# <AIX-SPE>",
"f-sharp" : "// <AIX-SPE>",
"stan" : "// <AIX-SPE>",
"isabelle" : "(*<AIX-SPE>*)",
"dockerfile" : "# <AIX-SPE>",
"rmarkdown" : "# <AIX-SPE>",
"literate-agda": "-- <AIX-SPE>",
"tcl" : "// <AIX-SPE>",
"glsl" : "// <AIX-SPE>",
"antlr" : "// <AIX-SPE>",
"verilog" : "// <AIX-SPE>",
"racket" : "; <AIX-SPE>",
"standard-ml" : "(*<AIX-SPE>*)",
"elm" : "-- <AIX-SPE>",
"yaml" : "# <AIX-SPE>",
"smalltalk" : "'' <AIX-SPE>",
"ocaml" : "(*<AIX-SPE>*)",
"idris" : "-- <AIX-SPE>",
"visual-basic" : "' <AIX-SPE>",
"protocol-buffer": "// <AIX-SPE>",
"bluespec" : "// <AIX-SPE>",
"applescript" : "-- <AIX-SPE>",
"makefile" : "# <AIX-SPE>",
"tcsh" : "# <AIX-SPE>",
"maple" : "# <AIX-SPE>",
"systemverilog": "// <AIX-SPE>",
"literate-coffeescript": "# <AIX-SPE>",
"vhdl" : "-- <AIX-SPE>",
"restructuredtext": ".. <AIX-SPE>",
"sas" : "* <AIX-SPE>",
"literate-haskell": "> <AIX-SPE>",
"java-server-pages": "// <AIX-SPE>",
"coffeescript" : "# <AIX-SPE>",
"emacs-lisp" : "; <AIX-SPE>",
"mathematica" : "// <AIX-SPE>",
"xslt" : "<!--<AIX-SPE>-->",
"zig" : "// <AIX-SPE>",
"common-lisp" : "; <AIX-SPE>",
"stata" : "* <AIX-SPE>",
"agda" : "-- <AIX-SPE>",
"ada" : "-- <AIX-SPE>",
"jsx" : "// <AIX-SPE>",
"tsx" : "// <AIX-SPE>",
}
EXT2LANG = {
".abap": "abap",
".ash": "ags script",
".ampl": "ampl",
".g4": "antlr",
".apib": "api blueprint",
".apl": "apl",
".dyalog": "apl",
".asp": "asp",
".asax": "asp",
".ascx": "asp",
".ashx": "asp",
".asmx": "asp",
".aspx": "asp",
".axd": "asp",
".dats": "ats",
".hats": "ats",
".sats": "ats",
".as": "actionscript",
".adb": "ada",
".ada": "ada",
".ads": "ada",
".agda": "agda",
".als": "alloy",
".apacheconf": "apacheconf",
".vhost": "apacheconf",
".applescript": "applescript",
".scpt": "applescript",
".arc": "arc",
".ino": "arduino",
".asciidoc": "asciidoc",
".adoc": "asciidoc",
".aj": "aspectj",
".asm": "assembly",
".a51": "assembly",
".nasm": "assembly",
".aug": "augeas",
".ahk": "autohotkey",
".ahkl": "autohotkey",
".au3": "autoit",
".awk": "awk",
".auk": "awk",
".gawk": "awk",
".mawk": "awk",
".nawk": "awk",
".bat": "batchfile",
".cmd": "batchfile",
".befunge": "befunge",
".bison": "bison",
".bb": "bitbake",
".decls": "blitzbasic",
".bmx": "blitzmax",
".bsv": "bluespec",
".boo": "boo",
".bf": "brainfuck",
".brs": "brightscript",
".bro": "bro",
".c": "c",
".cats": "c",
".h": "c++",
".idc": "c",
".w": "c",
".cs": "c#",
".cake": "c#",
".cshtml": "c#",
".csx": "c#",
".cpp": "c++",
".c++": "c++",
".cc": "c++",
".cp": "c++",
".cxx": "c++",
".h++": "c++",
".hh": "c++",
".hpp": "c++",
".hxx": "c++",
".inl": "c++",
".ipp": "c++",
".tcc": "c++",
".tpp": "c++",
".C": "c++",
".H": "c++",
".c-objdump": "c-objdump",
".chs": "c2hs haskell",
".clp": "clips",
".cmake": "cmake",
".cmake.in": "cmake",
".cob": "cobol",
".cbl": "cobol",
".ccp": "cobol",
".cobol": "cobol",
".cpy": "cobol",
".css": "css",
".csv": "csv",
".capnp": "cap'n proto",
".mss": "cartocss",
".ceylon": "ceylon",
".chpl": "chapel",
".ck": "chuck",
".cirru": "cirru",
".clw": "clarion",
".icl": "clean",
".dcl": "clean",
".click": "click",
".clj": "clojure",
".boot": "clojure",
".cl2": "clojure",
".cljc": "clojure",
".cljs": "clojure",
".cljs.hl": "clojure",
".cljscm": "clojure",
".cljx": "clojure",
".hic": "clojure",
".coffee": "coffeescript",
"._coffee": "coffeescript",
".cjsx": "coffeescript",
".cson": "coffeescript",
".iced": "coffeescript",
".cfm": "coldfusion",
".cfml": "coldfusion",
".cfc": "coldfusion cfc",
".lisp": "common lisp",
".asd": "common lisp",
".lsp": "common lisp",
".ny": "common lisp",
".podsl": "common lisp",
".sexp": "common lisp",
".cps": "component pascal",
".coq": "coq",
".cppobjdump": "cpp-objdump",
".c++-objdump": "cpp-objdump",
".c++objdump": "cpp-objdump",
".cpp-objdump": "cpp-objdump",
".cxx-objdump": "cpp-objdump",
".creole": "creole",
".cr": "crystal",
".csd": "csound",
".feature": "cucumber",
".cu": "cuda",
".cuh": "cuda",
".cy": "cycript",
".pyx": "cython",
".pxd": "cython",
".pxi": "cython",
".di": "d",
".d-objdump": "d-objdump",
".com": "digital command language",
".dm": "dm",
".zone": "dns zone",
".arpa": "dns zone",
".darcspatch": "darcs patch",
".dpatch": "darcs patch",
".dart": "dart",
".diff": "diff",
".patch": "diff",
".dockerfile": "dockerfile",
"Dockerfile": "dockerfile",
".djs": "dogescript",
".dylan": "dylan",
".dyl": "dylan",
".intr": "dylan",
".lid": "dylan",
".E": "e",
".ecl": "ecl",
".eclxml": "ecl",
".sch": "eagle",
".brd": "eagle",
".epj": "ecere projects",
".e": "eiffel",
".ex": "elixir",
".exs": "elixir",
".elm": "elm",
".el": "emacs lisp",
".emacs": "emacs lisp",
".emacs.desktop": "emacs lisp",
".em": "emberscript",
".emberscript": "emberscript",
".erl": "erlang",
".escript": "erlang",
".hrl": "erlang",
".xrl": "erlang",
".yrl": "erlang",
".fs": "f#",
".fsi": "f#",
".fsx": "f#",
".flux": "flux",
".f90": "fortran",
".f": "fortran",
".f03": "fortran",
".f08": "fortran",
".f77": "fortran",
".f95": "fortran",
".for": "fortran",
".fpp": "fortran",
".factor": "factor",
".fy": "fancy",
".fancypack": "fancy",
".fan": "fantom",
".eam.fs": "formatted",
".fth": "forth",
".4th": "forth",
".forth": "forth",
".frt": "forth",
".ftl": "freemarker",
".g": "g-code",
".gco": "g-code",
".gcode": "g-code",
".gms": "gams",
".gap": "gap",
".gi": "gap",
".s": "gas",
".gd": "gdscript",
".glsl": "glsl",
".fp": "glsl",
".frag": "glsl",
".frg": "glsl",
".fsh": "glsl",
".fshader": "glsl",
".geo": "glsl",
".geom": "glsl",
".glslv": "glsl",
".gshader": "glsl",
".shader": "glsl",
".vert": "glsl",
".vrx": "glsl",
".vsh": "glsl",
".vshader": "glsl",
".kid": "genshi",
".ebuild": "gentoo ebuild",
".eclass": "gentoo eclass",
".po": "gettext catalog",
".pot": "gettext catalog",
".glf": "glyph",
".gp": "gnuplot",
".gnu": "gnuplot",
".gnuplot": "gnuplot",
".plot": "gnuplot",
".plt": "gnuplot",
".go": "go",
".golo": "golo",
".gst": "gosu",
".gsx": "gosu",
".vark": "gosu",
".grace": "grace",
".gradle": "gradle",
".gf": "grammatical framework",
".graphql": "graphql",
".dot": "graphviz (dot)",
".gv": "graphviz (dot)",
".man": "groff",
".1": "groff",
".1in": "groff",
".1m": "groff",
".1x": "groff",
".2": "groff",
".3": "groff",
".3in": "groff",
".3m": "groff",
".3qt": "groff",
".3x": "groff",
".4": "groff",
".5": "groff",
".6": "groff",
".7": "groff",
".8": "groff",
".9": "groff",
".me": "groff",
".rno": "groff",
".roff": "groff",
".groovy": "groovy",
".grt": "groovy",
".gtpl": "groovy",
".gvy": "groovy",
".gsp": "groovy server pages",
".hcl": "hcl",
".tf": "hcl",
".hlsl": "hlsl",
".fxh": "hlsl",
".hlsli": "hlsl",
".html": "html",
".htm": "html",
".html.hl": "html",
".xht": "html",
".xhtml": "html",
".mustache": "html+django",
".jinja": "html+django",
".eex": "html+eex",
".erb": "html+erb",
".erb.deface": "html+erb",
".phtml": "html+php",
".http": "http",
".haml": "haml",
".haml.deface": "haml",
".handlebars": "handlebars",
".hbs": "handlebars",
".hb": "harbour",
".hs": "haskell",
".hsc": "haskell",
".hx": "haxe",
".hxsl": "haxe",
".hy": "hy",
".dlm": "idl",
".ipf": "igor pro",
".ini": "ini",
".cfg": "ini",
".prefs": "ini",
".properties": "ini",
".irclog": "irc log",
".weechatlog": "irc log",
".idr": "idris",
".lidr": "idris",
".ni": "inform 7",
".i7x": "inform 7",
".iss": "inno setup",
".io": "io",
".ik": "ioke",
".thy": "isabelle",
".ijs": "j",
".flex": "jflex",
".jflex": "jflex",
".json": "json",
".geojson": "json",
".lock": "json",
".topojson": "json",
".json5": "json5",
".jsonld": "jsonld",
".jq": "jsoniq",
".jsx": "jsx",
".jade": "jade",
".j": "jasmin",
".java": "java",
".jsp": "java server pages",
".js": "javascript",
"._js": "javascript",
".bones": "javascript",
".es6": "javascript",
".jake": "javascript",
".jsb": "javascript",
".jscad": "javascript",
".jsfl": "javascript",
".jsm": "javascript",
".jss": "javascript",
".njs": "javascript",
".pac": "javascript",
".sjs": "javascript",
".ssjs": "javascript",
".xsjs": "javascript",
".xsjslib": "javascript",
".jl": "julia",
".ipynb": "jupyter notebook",
".krl": "krl",
".kicad_pcb": "kicad",
".kit": "kit",
".kt": "kotlin",
".ktm": "kotlin",
".kts": "kotlin",
".lfe": "lfe",
".ll": "llvm",
".lol": "lolcode",
".lsl": "lsl",
".lslp": "lsl",
".lvproj": "labview",
".lasso": "lasso",
".las": "lasso",
".lasso8": "lasso",
".lasso9": "lasso",
".ldml": "lasso",
".latte": "latte",
".lean": "lean",
".hlean": "lean",
".less": "less",
".lex": "lex",
".ly": "lilypond",
".ily": "lilypond",
".ld": "linker script",
".lds": "linker script",
".liquid": "liquid",
".lagda": "literate agda",
".litcoffee": "literate coffeescript",
".lhs": "literate haskell",
".ls": "livescript",
"._ls": "livescript",
".xm": "logos",
".x": "logos",
".xi": "logos",
".lgt": "logtalk",
".logtalk": "logtalk",
".lookml": "lookml",
".lua": "lua",
".nse": "lua",
".pd_lua": "lua",
".rbxs": "lua",
".wlua": "lua",
".mumps": "m",
".m4": "m4",
".mcr": "maxscript",
".mtml": "mtml",
".muf": "muf",
".mak": "makefile",
".mk": "makefile",
".mkfile": "makefile",
"Makefile": "makefile",
".mako": "mako",
".mao": "mako",
".mpl": "maple",
".md": "markdown",
".markdown": "markdown",
".mkd": "markdown",
".mkdn": "markdown",
".mkdown": "markdown",
".ron": "markdown",
".mask": "mask",
".mathematica": "mathematica",
".cdf": "mathematica",
".ma": "mathematica",
".mt": "mathematica",
".nb": "mathematica",
".nbp": "mathematica",
".wl": "mathematica",
".wlt": "mathematica",
".matlab": "matlab",
".maxpat": "max",
".maxhelp": "max",
".maxproj": "max",
".mxt": "max",
".pat": "max",
".mediawiki": "mediawiki",
".wiki": "mediawiki",
".metal": "metal",
".minid": "minid",
".druby": "mirah",
".duby": "mirah",
".mir": "mirah",
".mirah": "mirah",
".mo": "modelica",
".mms": "module management system",
".mmk": "module management system",
".monkey": "monkey",
".moon": "moonscript",
".myt": "myghty",
".nsi": "nsis",
".nsh": "nsis",
".axs": "netlinx",
".axi": "netlinx",
".axs.erb": "netlinx+erb",
".axi.erb": "netlinx+erb",
".nlogo": "netlogo",
".nginxconf": "nginx",
".nim": "nimrod",
".nimrod": "nimrod",
".ninja": "ninja",
".nit": "nit",
".nix": "nix",
".nu": "nu",
".numpy": "numpy",
".numpyw": "numpy",
".numsc": "numpy",
".ml": "ocaml",
".eliom": "ocaml",
".eliomi": "ocaml",
".ml4": "ocaml",
".mli": "ocaml",
".mll": "ocaml",
".mly": "ocaml",
".objdump": "objdump",
".mm": "objective-c++",
".sj": "objective-j",
".oct": "octave",
".omgrofl": "omgrofl",
".opa": "opa",
".opal": "opal",
".cl": "opencl",
".opencl": "opencl",
".p": "openedge abl",
".scad": "openscad",
".org": "org",
".ox": "ox",
".oxh": "ox",
".oxo": "ox",
".oxygene": "oxygene",
".oz": "oz",
".pwn": "pawn",
".php": "php",
".aw": "php",
".ctp": "php",
".php3": "php",
".php4": "php",
".php5": "php",
".phps": "php",
".phpt": "php",
".pov": "pov-ray sdl",
".pan": "pan",
".psc": "papyrus",
".parrot": "parrot",
".pasm": "parrot assembly",
".pir": "parrot internal representation",
".pas": "pascal",
".dfm": "pascal",
".dpr": "pascal",
".lpr": "pascal",
".pl": "perl",
".al": "perl",
".perl": "perl",
".ph": "perl",
".plx": "perl",
".pm": "perl",
".psgi": "perl",
".t": "perl",
".6pl": "perl6",
".6pm": "perl6",
".nqp": "perl6",
".p6": "perl6",
".p6l": "perl6",
".p6m": "perl6",
".pl6": "perl6",
".pm6": "perl6",
".pkl": "pickle",
".pig": "piglatin",
".pike": "pike",
".pmod": "pike",
".pod": "pod",
".pogo": "pogoscript",
".pony": "pony",
".ps": "postscript",
".eps": "postscript",
".ps1": "powershell",
".psd1": "powershell",
".psm1": "powershell",
".pde": "processing",
".prolog": "prolog",
".yap": "prolog",
".spin": "propeller spin",
".proto": "protocol buffer",
".pub": "public key",
".pd": "pure data",
".pb": "purebasic",
".pbi": "purebasic",
".purs": "purescript",
".py": "python",
".bzl": "python",
".gyp": "python",
".lmi": "python",
".pyde": "python",
".pyp": "python",
".pyt": "python",
".pyw": "python",
".tac": "python",
".wsgi": "python",
".xpy": "python",
".pytb": "python traceback",
".qml": "qml",
".qbs": "qml",
".pri": "qmake",
".r": "r",
".rd": "r",
".rsx": "r",
".raml": "raml",
".rdoc": "rdoc",
".rbbas": "realbasic",
".rbfrm": "realbasic",
".rbmnu": "realbasic",
".rbres": "realbasic",
".rbtbar": "realbasic",
".rbuistate": "realbasic",
".rhtml": "rhtml",
".rmd": "rmarkdown",
".rkt": "racket",
".rktd": "racket",
".rktl": "racket",
".scrbl": "racket",
".rl": "ragel in ruby host",
".raw": "raw token data",
".reb": "rebol",
".r2": "rebol",
".r3": "rebol",
".rebol": "rebol",
".red": "red",
".reds": "red",
".cw": "redcode",
".rpy": "ren'py",
".rsh": "renderscript",
".robot": "robotframework",
".rg": "rouge",
".rb": "ruby",
".builder": "ruby",
".gemspec": "ruby",
".god": "ruby",
".irbrc": "ruby",
".jbuilder": "ruby",
".mspec": "ruby",
".podspec": "ruby",
".rabl": "ruby",
".rake": "ruby",
".rbuild": "ruby",
".rbw": "ruby",
".rbx": "ruby",
".ru": "ruby",
".ruby": "ruby",
".thor": "ruby",
".watchr": "ruby",
".rs": "rust",
".rs.in": "rust",
".sas": "sas",
".scss": "scss",
".smt2": "smt",
".smt": "smt",
".sparql": "sparql",
".rq": "sparql",
".sqf": "sqf",
".hqf": "sqf",
".pls": "sql",
".pck": "sql",
".pkb": "sql",
".pks": "sql",
".plb": "sql",
".plsql": "sql",
".sql": "sql",
".cql": "sql",
".ddl": "sql",
".prc": "sql",
".tab": "sql",
".udf": "sql",
".viw": "sql",
".db2": "sql",
".ston": "ston",
".svg": "svg",
".sage": "sage",
".sagews": "sage",
".sls": "saltstack",
".sass": "sass",
".scala": "scala",
".sbt": "scala",
".scaml": "scaml",
".scm": "scheme",
".sld": "scheme",
".sps": "scheme",
".ss": "scheme",
".sci": "scilab",
".sce": "scilab",
".self": "self",
".sh": "shell",
".bash": "shell",
".bats": "shell",
".command": "shell",
".ksh": "shell",
".sh.in": "shell",
".tmux": "shell",
".tool": "shell",
".zsh": "shell",
".sh-session": "shellsession",
".shen": "shen",
".sl": "slash",
".slim": "slim",
".smali": "smali",
".st": "smalltalk",
".tpl": "smarty",
".sol": "solidity",
".sp": "sourcepawn",
".sma": "sourcepawn",
".nut": "squirrel",
".stan": "stan",
".ML": "standard ml",
".fun": "standard ml",
".sig": "standard ml",
".sml": "standard ml",
".do": "stata",
".ado": "stata",
".doh": "stata",
".ihlp": "stata",
".mata": "stata",
".matah": "stata",
".sthlp": "stata",
".styl": "stylus",
".scd": "supercollider",
".swift": "swift",
".sv": "systemverilog",
".svh": "systemverilog",
".vh": "systemverilog",
".toml": "toml",
".txl": "txl",
".tcl": "tcl",
".adp": "tcl",
".tm": "tcl",
".tcsh": "tcsh",
".csh": "tcsh",
".tex": "tex",
".aux": "tex",
".bbx": "tex",
".bib": "tex",
".cbx": "tex",
".dtx": "tex",
".ins": "tex",
".lbx": "tex",
".ltx": "tex",
".mkii": "tex",
".mkiv": "tex",
".mkvi": "tex",
".sty": "tex",
".toc": "tex",
".tea": "tea",
".txt": "text",
".no": "text",
".textile": "textile",
".thrift": "thrift",
".tu": "turing",
".ttl": "turtle",
".twig": "twig",
".ts": "typescript",
".tsx": "tsx",
".upc": "unified parallel c",
".anim": "unity3d asset",
".asset": "unity3d asset",
".mat": "unity3d asset",
".meta": "unity3d asset",
".prefab": "unity3d asset",
".unity": "unity3d asset",
".uno": "uno",
".uc": "unrealscript",
".ur": "urweb",
".urs": "urweb",
".vcl": "vcl",
".vhdl": "vhdl",
".vhd": "vhdl",
".vhf": "vhdl",
".vhi": "vhdl",
".vho": "vhdl",
".vhs": "vhdl",
".vht": "vhdl",
".vhw": "vhdl",
".vala": "vala",
".vapi": "vala",
".veo": "verilog",
".vim": "viml",
".vb": "visual basic",
".bas": "visual basic",
".frm": "visual basic",
".frx": "visual basic",
".vba": "visual basic",
".vbhtml": "visual basic",
".vbs": "visual basic",
".volt": "volt",
".vue": "vue",
".owl": "web ontology language",
".wat": "webassembly",
".webidl": "webidl",
".x10": "x10",
".xc": "xc",
".xml": "xml",
".ant": "xml",
".axml": "xml",
".ccxml": "xml",
".clixml": "xml",
".cproject": "xml",
".csl": "xml",
".csproj": "xml",
".ct": "xml",
".dita": "xml",
".ditamap": "xml",
".ditaval": "xml",
".dll.config": "xml",
".dotsettings": "xml",
".filters": "xml",
".fsproj": "xml",
".fxml": "xml",
".glade": "xml",
".grxml": "xml",
".iml": "xml",
".ivy": "xml",
".jelly": "xml",
".jsproj": "xml",
".kml": "xml",
".launch": "xml",
".mdpolicy": "xml",
".mxml": "xml",
".nproj": "xml",
".nuspec": "xml",
".odd": "xml",
".osm": "xml",
".plist": "xml",
".props": "xml",
".ps1xml": "xml",
".psc1": "xml",
".pt": "xml",
".rdf": "xml",
".rss": "xml",
".scxml": "xml",
".srdf": "xml",
".storyboard": "xml",
".stTheme": "xml",
".sublime-snippet": "xml",
".targets": "xml",
".tmCommand": "xml",
".tml": "xml",
".tmLanguage": "xml",
".tmPreferences": "xml",
".tmSnippet": "xml",
".tmTheme": "xml",
".ui": "xml",
".urdf": "xml",
".ux": "xml",
".vbproj": "xml",
".vcxproj": "xml",
".vssettings": "xml",
".vxml": "xml",
".wsdl": "xml",
".wsf": "xml",
".wxi": "xml",
".wxl": "xml",
".wxs": "xml",
".x3d": "xml",
".xacro": "xml",
".xaml": "xml",
".xib": "xml",
".xlf": "xml",
".xliff": "xml",
".xmi": "xml",
".xml.dist": "xml",
".xproj": "xml",
".xsd": "xml",
".xul": "xml",
".zcml": "xml",
".xsp-config": "xpages",
".xsp.metadata": "xpages",
".xpl": "xproc",
".xproc": "xproc",
".xquery": "xquery",
".xq": "xquery",
".xql": "xquery",
".xqm": "xquery",
".xqy": "xquery",
".xs": "xs",
".xslt": "xslt",
".xsl": "xslt",
".xojo_code": "xojo",
".xojo_menu": "xojo",
".xojo_report": "xojo",
".xojo_script": "xojo",
".xojo_toolbar": "xojo",
".xojo_window": "xojo",
".xtend": "xtend",
".yml": "yaml",
".reek": "yaml",
".rviz": "yaml",
".sublime-syntax": "yaml",
".syntax": "yaml",
".yaml": "yaml",
".yaml-tmlanguage": "yaml",
".yang": "yang",
".y": "yacc",
".yacc": "yacc",
".yy": "yacc",
".zep": "zephir",
".zig": "zig",
".zimpl": "zimpl",
".zmpl": "zimpl",
".zpl": "zimpl",
".desktop": "desktop",
".desktop.in": "desktop",
".ec": "ec",
".eh": "ec",
".edn": "edn",
".fish": "fish",
".mu": "mupad",
".nc": "nesc",
".ooc": "ooc",
".rst": "restructuredtext",
".rest": "restructuredtext",
".rest.txt": "restructuredtext",
".rst.txt": "restructuredtext",
".wisp": "wisp",
".prg": "xbase",
".prw": "xbase"
}
LANGUAGE_TAG = {
"c" : "// the code file is written by C",
"c++" : "// the code file is written by C++",
"cpp" : "// the code file is written by C++",
"c#" : "// the code file is written by C#",
"csharp" : "// the code file is written by C#",
"c-sharp" : "// the code file is written by C#",
"css" : "/* the code file is written by CSS */",
"cuda" : "// the code file is written by Cuda",
"dart" : "// the code file is written by Dart",
"lua" : "// the code file is written by Lua",
"objectivec" : "// the code file is written by Objective-C",
"objective-c" : "// the code file is written by Objective-C",
"objective-c++": "// the code file is written by Objective-C++",
"python" : "# the code file is written by Python",
"perl" : "# the code file is written by Perl",
"prolog" : "% the code file is written by Prolog",
"swift" : "// the code file is written by swift",
"lisp" : "; the code file is written by Lisp",
"java" : "// the code file is written by Java",
"scala" : "// the code file is written by Scala",
"tex" : "% the code file is written by TeX",
"vue" : "<!--the code file is written by Vue-->",
"markdown" : "<!--the code file is written by Markdown-->",
"html" : "<!--the code file is written by HTML-->",
"php" : "// the code file is written by PHP",
"js" : "// the code file is written by JavaScript",
"javascript" : "// the code file is written by JavaScript",
"typescript" : "// the code file is written by TypeScript",
"go" : "// the code file is written by Go",
"shell" : "# the code file is written by Shell",
"rust" : "// the code file is written by Rust",
"sql" : "-- the code file is written by SQL",
"kotlin" : "// the code file is written by Kotlin",
"vb" : "' the code file is written by Visual Basic",
"ruby" : "# the code file is written by Ruby",
"pascal" : "// the code file is written by Pascal",
"r" : "# the code file is written by R",
"fortran" : "!the code file is written by Fortran",
"lean" : "-- the code file is written by Lean",
"matlab" : "% the code file is written by Matlab",
"delphi" : "{the code file is written by Delphi}",
"scheme" : "; the code file is written by Scheme",
"basic" : "' the code file is written by Basic",
"assembly" : "; the code file is written by Assembly",
"groovy" : "// the code file is written by Groovy",
"abap" : "* the code file is written by Abap",
"gdscript" : "# the code file is written by GDScript",
"haskell" : "-- the code file is written by Haskell",
"julia" : "# the code file is written by Julia",
"elixir" : "# the code file is written by Elixir",
"excel" : "' the code file is written by Excel",
"clojure" : "; the code file is written by Clojure",
"actionscript" : "// the code file is written by ActionScript",
"solidity" : "// the code file is written by Solidity",
"powershell" : "# the code file is written by PowerShell",
"erlang" : "% the code file is written by Erlang",
"cobol" : "// the code file is written by Cobol",
"alloy" : "/* the code file is written by Alloy */",
"awk" : "// the code file is written by AWK",
"thrift" : "/* the code file is written by Thrift */",
"sparql" : "# the code file is written by SPARQL",
"augeas" : "// the code file is written by Augeas",
"cmake" : "# the code file is written by CMake",
"f-sharp" : "// the code file is written by F#",
"stan" : "// the code file is written by Stan",
"isabelle" : "(*the code file is written by Isabelle*)",
"dockerfile" : "# the code file is written by Dockerfile",
"rmarkdown" : "# the code file is written by RMarkdown",
"literate-agda": "-- the code file is written by Literate Agda",
"tcl" : "// the code file is written by Augeas",
"glsl" : "// the code file is written by GLSL",
"antlr" : "// the code file is written by ANTLR",
"verilog" : "// the code file is written by Verilog",
"racket" : "; the code file is written by Racket",
"standard-ml" : "(*the code file is written byStandard ML*)",
"elm" : "-- the code file is written by Elm",
"yaml" : "# the code file is written by YAML",
"smalltalk" : "'' the code file is written by Smalltalk",
"ocaml" : "(*the code file is written by OCaml*)",
"idris" : "-- the code file is written by Idris",
"visual-basic" : "' the code file is written by Visual Basic",
"protocol-buffer": "// the code file is written by Protocol Buffer",
"bluespec" : "// the code file is written by Bluespec",
"applescript" : "-- the code file is written by AppleScript",
"makefile" : "# the code file is written by Makefile",
"tcsh" : "# the code file is written by TCSH",
"maple" : "# the code file is written by Maple",
"systemverilog": "// the code file is written by SystemVerilog",
"literate-coffeescript": "# the code file is written by Literate CoffeeScript",
"vhdl" : "-- the code file is written by VHDL",
"restructuredtext": ".. the code file is written by reStructuredText",
"sas" : "* the code file is written by SAS",
"literate-haskell": "> the code file is written by Literate Haskell",
"java-server-pages": "// the code file is written by Java Server Pages",
"coffeescript" : "# the code file is written by CoffeeScript",
"emacs-lisp" : "; the code file is written by Emacs Lisp",
"mathematica" : "// the code file is written by Mathematica",
"xslt" : "<!--the code file is written by XSLT-->",
"zig" : "// the code file is written by Zig",
"common-lisp" : "; the code file is written by Common Lisp",
"stata" : "* the code file is written by Stata",
"agda" : "-- the code file is written by Agda",
"ada" : "-- the code file is written by Ada",
"jsx" : "// the code file is written by JSX",
"tsx" : "// the code file is written by TypeScript with JSX",
}
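# Note (based on the usage in Tokenizer.__encode below): each entry above is a
# one-line comment written in the target language's own comment syntax; it is
# prepended to the prompt as a language hint via the LANGUAGE_TAG lookup, while
# LANGUAGE_WRAPPER entries additionally embed the file path through the
# "<AIX-SPE>" placeholder.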
class Tokenizer:
def __init__(self, rank: int = 0, model_path: str = "", logger_info=True):
# reload tokenizer
from sentencepiece import SentencePieceProcessor
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
if rank == 0 and logger_info:
print(f"Reloaded SentencePiece model from {model_path}", flush=True)
# BOS / EOS token IDs
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()
if self.pad_id < 0:
self.pad_id = self.eos_id
# token IDs for special infilling tokens
self.prefix_id: Optional[int] = self.sp_model.piece_to_id("▁<AIX-SPAN-PRE>") or None
self.middle_id: Optional[int] = self.sp_model.piece_to_id("▁<AIX-SPAN-MIDDLE>") or None
self.suffix_id: Optional[int] = self.sp_model.piece_to_id("▁<AIX-SPAN-POST>") or None
self.prefix_tok_id = self.prefix_id
self.suffix_tok_id = self.suffix_id
self.middle_tok_id = self.middle_id
self.pad_tok_id = self.pad_id
self.extension_pattern = re.compile(r"(\.\w+)$")
self.unk_token = "☺"
self.unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
self.user_id: Optional[int] = self.sp_model.piece_to_id("▁<AIX-USER>") or None
self.assistant_id: Optional[int] = self.sp_model.piece_to_id("▁<AIX-ASSISTANT>") or None
self.eot_id: Optional[int] = self.sp_model.piece_to_id("▁<AIX-END-TURN>") or None
self.end_token_set = {
self.bos_id, self.eos_id, self.pad_id, self.eot_id,
self.prefix_id, self.middle_id, self.suffix_id
}
self.is_security = SensitiveInforRM()
if rank == 0 and logger_info:
print(
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id} - PAD ID: {self.pad_id} "
f"- PRE ID: {self.prefix_id} - MID ID: {self.middle_id} - SUF ID: {self.suffix_id} - EOT ID: {self.eot_id}",
flush=True
)
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
def __encode(self, s: str, path: str = None, is_fim: bool = False) -> List[int]:
p = ""
if path is not None and len(path) > 0:
extension = self.extension_pattern.search(path)
if extension is not None:
extension = extension.groups()[0]
lang = EXT2LANG.get(extension, "")
des = LANGUAGE_TAG.get(lang, "")
if len(des) > 0:
s = des + "\n" + s
des = LANGUAGE_WRAPPER.get(lang, "")
if len(des) > 0 and "<AIX-SPE>" in des:
p = des.replace("<AIX-SPE>", f"the file path is: {path}") + "\n"
if is_fim:
tokens = self.sp_model.encode(self.unk_token + p + s)
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
else:
return self.sp_model.encode(p + s)
def encode(self, code_string: str, later_code: str, file_path: str) -> List[int]:
start = time.time()
        _security = True
        for i in [code_string, later_code, file_path]:
            if not self.is_security.is_security(i):
                _security = False
                break
        print(f"Finished input security check in {(time.time() - start) * 1000:.2f} ms", flush=True)
        if not _security:
return []
assert len(code_string) > 0
if len(later_code) == 0:
t = self.__encode(code_string, file_path, False)
t = [self.bos_id] + t
else:
t = [self.bos_id, self.prefix_tok_id, self.suffix_tok_id] + self.__encode(later_code, None, True)
t += [self.middle_tok_id] + self.__encode(code_string, file_path, False)
return t
def decode(self, t: List[int], is_fim: bool = False) -> str:
        if not isinstance(t, list):
            raise ValueError("decode() expects a list of token ids")
if is_fim:
return self.sp_model.decode([self.sp_model.piece_to_id("☺")] + t)[1:]
else:
return self.sp_model.decode(t)
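# Rough usage sketch (illustrative only; the paths and code snippets below are
# placeholders, not values from this repo):
#
#   tok = Tokenizer(model_path="/path/to/tokenizer.model")
#   ids = tok.encode(code_string="def add(a, b):", later_code="\n", file_path="calc.py")
#   # With a non-empty later_code, encode() builds a fill-in-the-middle prompt:
#   #   [BOS, <AIX-SPAN-PRE>, <AIX-SPAN-POST>, later_code tokens..., <AIX-SPAN-MIDDLE>, code_string tokens...]
#   text = tok.decode(ids)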
\ No newline at end of file
accelerate==0.27.1
datasets>=2.16.1
bitsandbytes==0.41.3
peft==0.8.2
trl==0.7.10
wandb==0.16.3
huggingface_hub==0.20.3
\ No newline at end of file
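# Quick-start example: fill-in-the-middle completion with the
# aiXcoder/aixcoder-7b-base checkpoint via Hugging Face transformers.
# input_wrapper (hf_mini.utils in this repo) assembles the FIM prompt;
# judging by the None check below, it can return None for rejected inputs.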
import torch
import sys
from hf_mini.utils import input_wrapper
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto
tokenizer = AutoTokenizer.from_pretrained("aiXcoder/aixcoder-7b-base")
model = AutoModelForCausalLM.from_pretrained("aiXcoder/aixcoder-7b-base", torch_dtype=torch.bfloat16)
inputs = input_wrapper(
tokenizer=tokenizer,
code_string="# 快速排序算法",
later_code="\n",
path="test.py",
)
if inputs is None:
sys.exit()
inputs = inputs.to(device)
model.to(device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=False))
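# Note: skip_special_tokens=False keeps any special markers in the decoded
# text, which shows where the generated span begins; set it to True for plain
# code (assuming the FIM markers are registered as special tokens in this
# tokenizer).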
\ No newline at end of file
import sys
import os
import pathlib
import torch
import traceback
import numpy as np
from typing import List, Tuple
from megatron_mini import get_args
from megatron_mini.initialize import initialize_megatron
from megatron_mini.model import LLaMAModel
from megatron_mini.utils import get_model_for_infer, Tokenizer
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True, file=sys.stderr)
else:
print(message, flush=True, file=sys.stderr)
def add_code_generation_args(parser):
"""Code generation arguments."""
group = parser.add_argument_group(title="code generation")
group.add_argument(
"--padded_vocab_size",
type=int,
default=40000,
help="Start id for whitespace encoding",
)
group.add_argument("--model_dir", type=str, default="")
group.add_argument("--model_name", type=str, default="aix3-7b-base")
return parser
class Predictor(object):
def __init__(self, args):
self.args = args
self.checkpoint_head_hash: str = ""
self.np_rand = np.random.RandomState(seed=1414)
# build predictor
self.tokenizer = self.create_tokenizer()
self.dtype = torch.float32
if self.args.bf16:
self.dtype = torch.bfloat16
elif self.args.fp16:
self.dtype = torch.half
self.predictor = self.create_predictor()
if torch.distributed.is_initialized():
torch.distributed.barrier()
@staticmethod
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("Building Codemodel ...")
model = LLaMAModel(parallel_output=False)
return model
@staticmethod
def pad_batch(tokens_id, max_seq_len=2048):
"""
        pad_batch pads token_ids up to max_seq_len so tensor shapes match across ranks before broadcast.
"""
tokens_id = np.reshape(tokens_id, [1, -1])
context_length = tokens_id.shape[-1]
assert context_length <= max_seq_len, f"{context_length}, {max_seq_len}"
if context_length < max_seq_len:
tokens_id = np.concatenate([tokens_id, np.zeros(shape=[1, max_seq_len-context_length], dtype=tokens_id.dtype)], axis=-1)
return tokens_id.astype(np.int64), np.array([context_length], dtype=np.int64)
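    # Worked example (illustrative): pad_batch([5, 7, 9], max_seq_len=6)
    #   -> (array([[5, 7, 9, 0, 0, 0]]), array([3]))
    # i.e. the tokens are right-padded with zeros and the original context
    # length (3) is returned alongside the padded batch.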
@staticmethod
def sync_type_info(sess_id: int) -> int:
input_info = np.array([sess_id], dtype=np.int64)
input_info_tensor = torch.tensor(input_info, dtype=torch.int64, device='cuda')
torch.distributed.broadcast(
input_info_tensor,
0,
)
sess_id = input_info_tensor[0].item()
return sess_id
@staticmethod
def sync_obj_info(model_dir: str) -> str:
tmp_list = [model_dir]
torch.distributed.broadcast_object_list(
tmp_list,
0,
)
return tmp_list[0]
def create_predictor(self):
model_dir = self.args.model_dir
assert self.args.num_attention_heads % self.args.tensor_model_parallel_size == 0
assert self.args.hidden_size % self.args.num_attention_heads == 0
model = get_model_for_infer(self.model_provider)
print_rank_0("Loading state dict ...")
_ = self.load_checkpoint(model, model_dir)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
model.eval()
        if self.args.bf16 or self.args.fp16:
print_rank_0(f" > converting model to {'bf16' if self.args.bf16 else 'fp16'} ...")
model.to(self.dtype)
print_rank_0(f" > moving model to GPU ...")
model.cuda(torch.cuda.current_device())
return model
def create_tokenizer(self):
assert os.path.exists(os.path.join(self.args.model_dir, "tokenizer.model"))
tokenizer = Tokenizer(model_path=os.path.join(self.args.model_dir, "tokenizer.model"))
return tokenizer
def load_checkpoint(self, model: List[LLaMAModel], path):
assert isinstance(model, list)
        if path is None or not os.path.exists(path):
            raise ValueError(f"checkpoint path does not exist: {path}")
iteration = 0
if self.args.tensor_model_parallel_size == 1 and self.args.rank < self.args.tensor_model_parallel_size:
checkpoint_name = os.path.join(path, f"{self.args.model_name}.pt")
assert os.path.isfile(checkpoint_name)
elif self.args.rank < self.args.tensor_model_parallel_size:
checkpoints = sorted(pathlib.Path(path).glob(f"{self.args.model_name}_states_*.pt"))
assert len(checkpoints) == self.args.tensor_model_parallel_size
checkpoint_name = checkpoints[self.args.rank]
else:
raise ValueError
# Load the checkpoint.
print(f"rank_{self.args.rank} load: {checkpoint_name}", flush=True, file=sys.stderr)
state_dict = torch.load(checkpoint_name, map_location="cpu")
# Set iteration.
iteration = state_dict.get("iteration", 0)
if "model" in state_dict:
state_dict = state_dict["model"]
if "module" in state_dict:
state_dict = state_dict["module"]
# Model.
model[0].load_state_dict(state_dict, strict=True)
print_rank_0(
f"successfully loaded checkpoint from {path} "
f"at iteration {iteration}"
)
return iteration
def predict_batch(self, data):
common_len = int(data[1].item())
with torch.no_grad():
tokens_ids = data[0].clone().detach().cuda()
logits = self.predictor(
tokens=tokens_ids, # shape: [bsz, 1024]
start_pos=common_len,
)
logits = logits[:, -1].view(1, -1).contiguous()
probs = torch.softmax(logits, dim=-1).cpu().numpy()
return [np.squeeze(probs)]
def predict(self, token_ids: List[int], common_len: int) -> Tuple[List[int], List[float]]:
if torch.distributed.is_initialized():
torch.distributed.barrier()
try:
common_len_nda = np.array([common_len]).astype("int64")
token_ids_nda = np.array([token_ids], dtype=np.int64)
max_pad_len = max(token_ids_nda.shape[-1], 128)
max_pad_len = self.sync_type_info(max_pad_len)
token_ids_nda, tokens_id_len = self.pad_batch(token_ids_nda, max_seq_len=max_pad_len)
context_tensor = torch.tensor(token_ids_nda, dtype=torch.int64, device='cuda')
context_tensor_length = torch.tensor(tokens_id_len, dtype=torch.int64, device='cuda')
context_common_len = torch.tensor(common_len_nda, dtype=torch.int64, device='cuda')
torch.distributed.broadcast(
context_tensor,
0,
)
torch.distributed.broadcast(
context_tensor_length,
0,
)
torch.distributed.broadcast(
context_common_len,
0,
)
tokens_id_len = context_tensor_length.min().item()
batch = [context_tensor[:, :tokens_id_len], context_common_len]
out = self.predict_batch(batch)
# shape: [bsz, vocab_size] => [vocab_size]
out = out[0]
predict_id = np.argmax(out)
return [int(predict_id)], [out[predict_id]]
except Exception as e:
traceback.print_exc(file=sys.stderr)
raise RuntimeError(e)
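# Synchronization sketch: rank 0 owns the real prompt and drives generation;
# every rank calls predict() each step, rank 0's token/length tensors are
# broadcast so all tensor-parallel ranks run the same forward pass, and
# sync_type_info() broadcasts the padded length and the per-step termination
# flag (see TestInference.run_infer below).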
class TestInference:
def __init__(self) -> None:
aix_config = {
"num_layers": 32, "hidden_size": 4096, "num_attention_heads": 32,
"max_position_embeddings": 32768, "fp16": False, "bf16": True,
"rope_theta": 256000, "inner_hidden_dim": 14464, "padded_vocab_size": 49152,
"seq_length": 4096, "micro_batch_size": 1, "use_flash_attn": True,
"use_cpu_initialization": True, "attention_head_type": "groupedquery"
}
initialize_megatron(
extra_args_provider=add_code_generation_args,
aix_config=aix_config
)
args = get_args()
self.sess = Predictor(args=args)
self.end_token_set = self.sess.tokenizer.end_token_set
    def run_infer(self, code_string: str, max_new_tokens: int = 256, later_code: str = "", file_path: str = "") -> str:
tokens = self.sess.tokenizer.encode(
code_string=code_string, later_code=later_code, file_path=file_path
)
if len(tokens) == 0:
return self.sess.sync_obj_info("")
predict_list = []
common_len = 0
while True:
if torch.distributed.get_rank() == 0:
                output_vals = self.sess.predict(tokens, common_len)
predict_list.append(output_vals[0][0])
if len(predict_list) >= max_new_tokens or predict_list[-1] in self.end_token_set:
terminate_runs = 1
else:
terminate_runs = 0
common_len += len(tokens)
tokens = predict_list[-1:]
else:
                # Non-zero ranks pass dummy inputs; the real token ids and
                # common_len are received via broadcast inside predict().
                tokens = [0] * 4
                output_vals = self.sess.predict(tokens, 0)
predict_list.append(0)
terminate_runs = 0
if self.sess.sync_type_info(terminate_runs) > 0:
break
return self.sess.sync_obj_info(self.sess.tokenizer.decode(predict_list))
if __name__ == "__main__":
infer = TestInference()
res = infer.run_infer(
code_string="""# 快速排序算法""",
later_code="\n",
file_path="test.py"
)
print(res)
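# Launch sketch (hypothetical paths; exact flags depend on the Megatron setup):
#   torchrun --nproc_per_node=<tensor-parallel size, 1 by default> <this_script>.py \
#       --model_dir /path/to/checkpoint_dir --model_name aix3-7b-base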