Unverified Commit 013a4bed authored by Jianghai, committed by GitHub

[inference]fix import bug and delete down useless init (#4830)

* fix import bug and release useless init

* fix

* fix

* fix
parent 573f2705
+import _utils
 from .bloom import BloomInferenceForwards
 from .chatglm2 import ChatGLM2InferenceForwards
 from .llama import LlamaInferenceForwards
......
""" """
Utils for model inference Utils for model inference
""" """
import os
import torch
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
"""
This function copies the key and value cache to the memory cache
Args:
layer_id : id of current layer
key_buffer : key cache
value_buffer : value cache
context_mem_index : index of memory cache in kv cache manager
mem_manager : cache manager
"""
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return
+def init_to_get_rotary(self, base=10000, use_elem=False):
+    """
+    This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer
+    Args:
+        self : Model that holds the rotary positional embedding
+        base : calculation arg
+        use_elem : activated when using chatglm-based models
+    """
+    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+    if not hasattr(self.config, "rope_scaling"):
+        rope_scaling_factor = 1.0
+    else:
+        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+
+    if hasattr(self.config, "max_sequence_length"):
+        max_seq_len = self.config.max_sequence_length
+    elif hasattr(self.config, "max_position_embeddings"):
+        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+    else:
+        max_seq_len = 2048 * rope_scaling_factor
+    base = float(base)
+
+    # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
+
+    if ntk_alpha is not None:
+        ntk_alpha = float(ntk_alpha)
+        assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
+        if ntk_alpha > 1:
+            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+        max_seq_len *= ntk_alpha
+        base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # Base change formula
+
+    n_elem = self.config.head_dim_
+    if use_elem:
+        n_elem //= 2
+
+    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+    freqs = torch.outer(t, inv_freq)
+
+    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+    return
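The NTK-aware scaling above only rescales the RoPE base. The following self-contained illustration is not part of the commit (the `head_dim` and `ntk_alpha` values are made up, standing in for the config value and the `INFER_NTK_ALPHA` environment variable); it just shows what the base-change formula does to the inverse frequencies:

```python
import torch

head_dim = 128
base = 10000.0
ntk_alpha = 2.0  # hypothetical value for INFER_NTK_ALPHA

# Base change formula used in init_to_get_rotary: base' = base * alpha^(d / (d - 2))
scaled_base = base * (ntk_alpha ** (head_dim / (head_dim - 2)))

exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
inv_freq = 1.0 / (base ** exponents)
inv_freq_ntk = 1.0 / (scaled_base ** exponents)

# The highest-frequency component is untouched (ratio 1.0) while the lowest-frequency
# one is slowed down by roughly alpha, which is what stretches the usable context length.
print(inv_freq[0] / inv_freq_ntk[0], inv_freq[-1] / inv_freq_ntk[-1])
```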
...@@ -5,12 +5,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import (
-    copy_kv_cache_to_dest,
-    llama_context_attn_fwd,
-    rotary_embedding_fwd,
-    token_attention_fwd,
-)
+from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+
+from ._utils import copy_kv_to_mem_cache

 try:
     from vllm import layernorm_ops, pos_encoding_ops
...@@ -46,12 +43,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed
-def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
-    copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
-    copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
-    return


 class LlamaInferenceForwards:
     """
     This class holds forwards for llama inference.
...@@ -285,11 +276,6 @@ class LlamaInferenceForwards:
         rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
         rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)

-        def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
-            copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
-            copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
-            return

         query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
         key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
         value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
...@@ -298,7 +284,7 @@ class LlamaInferenceForwards:
             # first token generation
             # copy key and value calculated in current step to memory manager
-            _copy_kv_to_mem_cache(
+            copy_kv_to_mem_cache(
                 infer_state.decode_layer_id,
                 key_states,
                 value_states,
...@@ -331,7 +317,7 @@ class LlamaInferenceForwards:
             else:
                 # if decode is not contiguous, use triton kernel to copy key and value cache
                # k, v shape:  [batch_size, num_heads, head_dim/embed_size_per_head
-                _copy_kv_to_mem_cache(
+                copy_kv_to_mem_cache(
                     infer_state.decode_layer_id,
                     key_states,
                     value_states,
......
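For context on what the shared helper does at these two call sites, here is a rough sketch of driving `copy_kv_to_mem_cache` with a stand-in cache manager. The `SimpleNamespace` object and all shapes/sizes are hypothetical (the real object is the engine's KV cache manager), and the underlying Triton kernel needs a CUDA device with triton installed:

```python
import torch
from types import SimpleNamespace

from colossalai.inference.tensor_parallel.modeling._utils import copy_kv_to_mem_cache

num_layers, cache_slots, num_heads, head_dim = 2, 16, 8, 64

# Stand-in for the KV cache manager: one [slots, heads, head_dim] buffer per layer.
mem_manager = SimpleNamespace(
    key_buffer=[torch.zeros(cache_slots, num_heads, head_dim, dtype=torch.float16, device="cuda") for _ in range(num_layers)],
    value_buffer=[torch.zeros(cache_slots, num_heads, head_dim, dtype=torch.float16, device="cuda") for _ in range(num_layers)],
)

# Four freshly computed K/V vectors for the current step and the cache slots they should land in.
key_states = torch.randn(4, num_heads, head_dim, dtype=torch.float16, device="cuda")
value_states = torch.randn_like(key_states)
context_mem_index = torch.tensor([3, 4, 5, 6], dtype=torch.long, device="cuda")

# Writes the K/V rows of layer 0 into slots 3..6 of the manager's buffers.
copy_kv_to_mem_cache(0, key_states, value_states, context_mem_index, mem_manager)
```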
 from functools import partial
-import torch

 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMForConditionalGeneration,
     ChatGLMModel,
...@@ -9,13 +7,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     GLMTransformer,
     SelfAttention,
 )

 # import colossalai
 from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy

-from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary
+from ..modeling._utils import init_to_get_rotary
+from ..modeling.chatglm2 import ChatGLM2InferenceForwards

 try:
     from colossalai.kernel.triton.rms_norm import rmsnorm_forward

     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
...@@ -23,7 +22,6 @@ except:
 class ChatGLM2InferPolicy(ChatGLMModelPolicy):
     def __init__(self) -> None:
         super().__init__()
...@@ -32,45 +30,44 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
         self.shard_config._infer()

         model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
-        method_replacement = {'forward': model_infer_forward}
+        method_replacement = {"forward": model_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)

         encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
-        method_replacement = {'forward': encoder_infer_forward}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=GLMTransformer)
+        method_replacement = {"forward": encoder_infer_forward}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=GLMTransformer
+        )

         encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
-        method_replacement = {'forward': encoder_layer_infer_forward}
+        method_replacement = {"forward": encoder_layer_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)

         attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
-        method_replacement = {'forward': attn_infer_forward}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=SelfAttention)
+        method_replacement = {"forward": attn_infer_forward}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=SelfAttention
+        )

         # for rmsnorm and others, we need to check the shape
         return policy

     def postprocess(self):
-        _init_to_get_rotary(self.model)
+        init_to_get_rotary(self.model)
         return self.model


 class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
     def __init__(self) -> None:
         super().__init__()

     def module_policy(self):
         policy = super().module_policy()
         model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
-        method_replacement = {'forward': partial(model_infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=ChatGLMForConditionalGeneration)
+        method_replacement = {"forward": partial(model_infer_forward)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration
+        )
         return policy

     def postprocess(self):
......
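All of the `method_replacement = {"forward": ...}` entries above follow the same pattern: roughly speaking, ShardFormer ends up rebinding `forward` on the target module to the inference-optimized implementation, which then receives the module instance as `self`. A toy illustration of that idea (the `Dummy` and `infer_forward` names are made up; this is not ShardFormer's actual code path):

```python
import types

class Dummy:
    def forward(self, x):
        return x + 1

def infer_forward(self, x):
    # Inference-optimized replacement; receives the module instance as `self`.
    return x * 2

m = Dummy()
m.forward = types.MethodType(infer_forward, m)  # analogous to {"forward": infer_forward}
print(m.forward(3))  # -> 6
```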
...@@ -3,11 +3,12 @@ from functools import partial
 import torch
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

-from colossalai.shardformer.layer import VocabParallelEmbedding1D
-from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

+from ..modeling._utils import init_to_get_rotary
 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

 try:
...@@ -50,38 +51,38 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=RowCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=ColCaiQuantLinear,
-                    kwargs={'split_num': 1},
+                    kwargs={"split_num": 1},
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=RowCaiQuantLinear,
-                    kwargs={'split_num': 1},
-                )
+                    kwargs={"split_num": 1},
+                ),
             ],
         )
...@@ -117,3 +118,7 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         )
         return policy
+
+    def postprocess(self):
+        init_to_get_rotary(self.model.model)
+        return self.model
...@@ -3,6 +3,12 @@ try:
     HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("Triton is not installed. Please install Triton to use Triton kernels.")
+
+# There may exist import error even if we have triton installed.
+if HAS_TRITON:
     from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
     from .copy_kv_cache_dest import copy_kv_cache_to_dest
     from .fused_layernorm import layer_norm
...@@ -23,7 +29,3 @@
         "token_attention_fwd",
         "gptq_fused_linear_triton",
     ]
-except ImportError:
-    HAS_TRITON = False
-    print("Triton is not installed. Please install Triton to use Triton kernels.")
...@@ -15,30 +15,6 @@ from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"

-def init_to_get_rotary(self, base=10000):
-    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
-    if not hasattr(self.config, "rope_scaling"):
-        rope_scaling_factor = 1.0
-    else:
-        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config, "max_sequence_length"):
-        max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config, "max_position_embeddings"):
-        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
-    else:
-        max_seq_len = 2048 * rope_scaling_factor
-    base = float(base)
-    inv_freq = 1.0 / (
-        base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)
-    )
-    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
-    freqs = torch.outer(t, inv_freq)
-    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
-    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
-    return

 def print_perf_stats(latency_set, config, bs, warmup=3):
     # trim warmup queries
     latency_set = list(latency_set)
...@@ -66,7 +42,6 @@ def run_llama_test(args):
     tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
-    init_to_get_rotary(model.model, base=10000)
     model = model.half()
     model_config = model.config
......
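The benchmark above no longer calls `init_to_get_rotary(model.model, base=10000)` by hand because `LlamaModelInferPolicy.postprocess()` now does it (see the policy diff earlier in this commit). Assuming `TPInferEngine` selects and applies that inference policy internally, which is what lets the script drop the manual call, the setup reduces to something like this sketch (model path and sizes are placeholders):

```python
from transformers import LlamaForCausalLM

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

model = LlamaForCausalLM.from_pretrained("path/to/llama").half()
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)

# Sharding the model through the engine is expected to run the policy's postprocess(),
# so the rotary cos/sin caches are attached without any explicit init_to_get_rotary call.
engine = TPInferEngine(model, shard_config, max_batch_size=8, max_input_len=1024, max_output_len=128)
```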
 import argparse
-import logging
 import os
 import time

 import torch
-from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
-from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
-from torch import distributed as dist
-from torch.profiler import ProfilerActivity, profile, record_function
-from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline
+from auto_gptq import AutoGPTQForCausalLM
+from transformers import LlamaTokenizer

 import colossalai
-from colossalai.gptq import CaiQuantLinear
-from colossalai.gptq.gptq_tp import replace_autogptq_linear
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"

-def init_to_get_rotary(self, base=10000):
-    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
-    if not hasattr(self.config, "rope_scaling"):
-        rope_scaling_factor = 1.0
-    else:
-        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config, "max_sequence_length"):
-        max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config, "max_position_embeddings"):
-        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
-    else:
-        max_seq_len = 2048 * rope_scaling_factor
-    base = float(base)
-    inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
-                      self.config.head_dim_))
-    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
-    freqs = torch.outer(t, inv_freq)
-    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
-    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
-    return

 def print_perf_stats(latency_set, config, bs, warmup=3):
...@@ -74,23 +46,23 @@ def run_llama_test(args):
     tokenizer.pad_token_id = tokenizer.eos_token_id
     # load quantized model to the first GPU
-    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
-                                               device=torch.cuda.current_device(),
-                                               inject_fused_attention=False)
+    model = AutoGPTQForCausalLM.from_quantized(
+        quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
+    )
     init_to_get_rotary(model.model.model, base=10000)
     model_config = model.config

-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
-                               inference_only=True,
-                               inference_gptq=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)

     input_tokens = {
-        "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
-        "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
+        "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
+        "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
     }
     iters = 10
...@@ -111,7 +83,7 @@ def run_llama_test(args):
 def check_llama(rank, world_size, port, args):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_llama_test(args)
...@@ -123,12 +95,12 @@ def test_llama(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
-    parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
-    parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
-    parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
-    parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
-    parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
+    parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
+    parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
+    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
+    parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
+    parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
     args = parser.parse_args()
......
...@@ -20,30 +20,6 @@ MAX_OUTPUT_LEN = 100
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")

-def init_to_get_rotary(self, base=10000):
-    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
-    if not hasattr(self.config, "rope_scaling"):
-        rope_scaling_factor = 1.0
-    else:
-        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config, "max_sequence_length"):
-        max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config, "max_position_embeddings"):
-        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
-    else:
-        max_seq_len = 2048 * rope_scaling_factor
-    base = float(base)
-    inv_freq = 1.0 / (
-        base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)
-    )
-    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
-    freqs = torch.outer(t, inv_freq)
-    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
-    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
-    return

 @parameterize(
     "test_config",
     [
...@@ -56,7 +32,6 @@ def run_llama_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
     for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
         orig_model = model_fn()
-        init_to_get_rotary(orig_model.model, base=10000)
         orig_model = orig_model.half()
         data = data_gen_fn()
......