Unverified commit ce7ade38, authored by Jianghai, committed by GitHub

[inference] chatglm2 infer demo (#4724)

* add chatglm2

* add

* gather needed kernels

* fix some bugs

* finish context forward

* finish context stage

* fix

* add

* pause

* add

* fix bugs

* finish chatglm

* fix bug

* change some logic

* fix bugs

* change some logics

* add

* add

* add

* fix

* fix tests

* fix
parent 946ab56c
@@ -16,7 +16,13 @@ from .kvcache_manager import MemoryManager
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
-_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"]
+_supported_models = [
+    "LlamaForCausalLM",
+    "LlamaModel",
+    "BloomForCausalLM",
+    "ChatGLMModel",
+    "ChatGLMForConditionalGeneration",
+]
class TPInferEngine:

@@ -64,7 +70,13 @@ class TPInferEngine:
        self.head_dim = model.config.hidden_size // model.config.num_attention_heads
        self.head_num = model.config.num_attention_heads
-        self.layer_num = model.config.num_hidden_layers
+        num_hidden_layers = (
+            model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
+        )
+        self.layer_num = num_hidden_layers
+        self.multi_query_group_num = (
+            model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
+        )

        self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
        self.cache_manager = None
@@ -85,6 +97,19 @@ class TPInferEngine:
        assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
        self.head_num //= self.tp_size  # update sharded number of heads
-        self.cache_manager = MemoryManager(
-            self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
-        )
+        if self.multi_query_group_num:
+            # NOTE the logic of MQA tensor parallelism should be specified.
+            assert (
+                self.multi_query_group_num % self.tp_size == 0
+            ), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}"
+            self.cache_manager = MemoryManager(
+                self.max_total_token_num,
+                self.dtype,
+                self.multi_query_group_num // self.tp_size,
+                self.head_dim,
+                self.layer_num,
+            )
+        else:
+            self.cache_manager = MemoryManager(
+                self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
+            )
...
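The branch above is where ChatGLM2's multi-query attention pays off at inference time: the KV cache is sized by query group rather than by attention head. A rough sizing sketch (illustrative ChatGLM2-6B-like numbers, not taken from this diff):

def kv_cache_elems(max_total_tokens, num_kv_heads, head_dim, layer_num):
    # one key buffer and one value buffer per layer, each [max_total_tokens, num_kv_heads, head_dim]
    return 2 * max_total_tokens * num_kv_heads * head_dim * layer_num

per_head_cache = kv_cache_elems(4096, num_kv_heads=32, head_dim=128, layer_num=28)   # MHA-style cache
per_group_cache = kv_cache_elems(4096, num_kv_heads=2, head_dim=128, layer_num=28)   # 2 query groups
print(per_head_cache // per_group_cache)  # 16, i.e. a 16x smaller KV cache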
+import _utils
from .bloom import BloomInferenceForwards
+from .chatglm2 import ChatGLM2InferenceForwards
from .llama import LlamaInferenceForwards

-__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"]
+__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"]
"""
Utils for model inference
"""
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return
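copy_kv_to_mem_cache scatters one layer's prefill keys and values into the slots returned by the memory manager. For reference, a plain PyTorch sketch of the same operation (assuming key_buffer/value_buffer are [num_tokens, num_kv_heads, head_dim] and the manager's per-layer buffers are indexed by token slot along dim 0):

import torch

def copy_kv_to_mem_cache_reference(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
    # dest[context_mem_index[i]] = src[i] for every prefill token, for both K and V
    idx = context_mem_index.long()
    mem_manager.key_buffer[layer_id].index_copy_(0, idx, key_buffer)
    mem_manager.value_buffer[layer_id].index_copy_(0, idx, value_buffer)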
This diff is collapsed.
@@ -100,7 +100,7 @@ class LlamaInferenceForwards:
        # NOTE: differentiate with prefill stage
        # block_loc require different value-assigning method for two different stage
        if use_cache and seq_length != 1:
-            # NOTE assuem prefill stage
+            # NOTE assume prefill stage
            # allocate memory block
            infer_state.is_context_stage = True  # set prefill stage, notify attention layer
            infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
...
from .bloom import BloomModelInferPolicy
+from .chatglm2 import ChatGLM2InferPolicy
from .llama import LlamaModelInferPolicy

-__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"]
+__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy", "ChatGLM2InferPolicy"]
from functools import partial
import torch
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
GLMTransformer,
SelfAttention,
)
# import colossalai
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary
try:
    from colossalai.kernel.triton.rms_norm import rmsnorm_forward

    HAS_TRITON_RMSNORM = True
except ImportError:
    print("please install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False
class ChatGLM2InferPolicy(ChatGLMModelPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
self.shard_config._infer()
model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
method_replacement = {'forward': model_infer_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
method_replacement = {'forward': encoder_infer_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=GLMTransformer)
encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
method_replacement = {'forward': encoder_layer_infer_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
method_replacement = {'forward': attn_infer_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=SelfAttention)
# for rmsnorm and others, we need to check the shape
return policy
def postprocess(self):
_init_to_get_rotary(self.model)
return self.model
class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
method_replacement = {'forward': partial(model_infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=ChatGLMForConditionalGeneration)
return policy
def postprocess(self):
return super().postprocess()
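For readers new to ShardFormer policies: each append_or_create_method_replacement call above asks the shardformer to swap the listed module's forward for the inference-specialized implementation when the model is sharded. Greatly simplified (a conceptual sketch, not the actual ShardFormer machinery), the effect per module instance is:

import types

def replace_forward(module, new_forward):
    # rebind `forward` on this instance to the replacement, e.g. chatglm_glmblock_forward for a GLMBlock
    module.forward = types.MethodType(new_forward, module)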
@@ -11,7 +11,6 @@ except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    """
    this function is modified from
@@ -240,3 +239,328 @@ if HAS_TRITON:
            num_stages=1,
        )
        return
@triton.jit
def _fwd_kernel_latest(
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
Out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
kv_group_num,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@triton.jit
def _fwd_kernel_old(
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_tmp_b,
stride_tmp_h,
stride_tmp_s,
kv_group_num,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + off_k
v_ptrs = V + off_v
t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
# t_ptrs = TMP + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad()
def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
if triton.__version__ >= "2.1.0":
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)  # attention scaling factor
batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # launch grid: (batch, head, blocks along seq)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel_latest[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
kv_group_num=kv_group_num,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
elif triton.__version__ == "2.0.0":
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
# num_warps = 4
_fwd_kernel_old[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
kv_group_num=kv_group_num,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
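As a sanity check for llama2_context_attn_fwd, a naive PyTorch reference for the same grouped-query causal attention over packed sequences might look as follows (a sketch only: it assumes q is [total_tokens, q_heads, head_dim], k and v are [total_tokens, kv_heads, head_dim], and b_start_loc/b_seq_len give each sequence's offset and length in the packed batch):

import torch

def context_attn_reference(q, k, v, b_start_loc, b_seq_len):
    out = torch.empty_like(q)
    kv_group_num = q.shape[1] // k.shape[1]
    sm_scale = 1.0 / (q.shape[-1] ** 0.5)
    for start, length in zip(b_start_loc.tolist(), b_seq_len.tolist()):
        qs = q[start : start + length]                                          # [L, Hq, d]
        ks = k[start : start + length].repeat_interleave(kv_group_num, dim=1)   # expand kv heads to q heads
        vs = v[start : start + length].repeat_interleave(kv_group_num, dim=1)
        scores = torch.einsum("ihd,jhd->hij", qs.float(), ks.float()) * sm_scale
        causal = torch.triu(torch.ones(length, length, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        out[start : start + length] = torch.einsum("hij,jhd->ihd", probs, vs.float()).to(q.dtype)
    return out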
@@ -105,3 +105,108 @@ def rotary_embedding_fwd(q, cos, sin):
        num_stages=1,
    )
    return
class Llama2Forwards:
@staticmethod
@triton.jit
def _rotary_kernel(
Q,
Cos,
Sin,
stride_qbs,
stride_qh,
stride_qd,
stride_cosbs,
stride_cosd,
stride_sinbs,
stride_sind,
max_total_len,
H,  # head num
BLOCK_HEAD: tl.constexpr,
BLOCK_SEQ: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
cur_head_index = tl.program_id(0)
cur_seq_index = tl.program_id(1)
cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2
dim_range1 = dim_range0 + 1
off_q0 = (
cur_seq_range[:, None, None] * stride_qbs
+ cur_head_range[None, :, None] * stride_qh
+ dim_range0[None, None, :] * stride_qd
)
off_q1 = (
cur_seq_range[:, None, None] * stride_qbs
+ cur_head_range[None, :, None] * stride_qh
+ dim_range1[None, None, :] * stride_qd
)
cos_range = tl.arange(0, BLOCK_DMODEL // 2)
off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd
q0 = tl.load(
Q + off_q0,
mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H),
other=0.0,
)
q1 = tl.load(
Q + off_q1,
mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H),
other=0.0,
)
cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
out0 = q0 * cos - q1 * sin
out1 = q0 * sin + q1 * cos
tl.store(
Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)
)
tl.store(
Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)
)
return
@staticmethod
@torch.no_grad()
def rotary_emb_fwd(q, cos, sin):
total_len = q.shape[0]
head_num = q.shape[1]
head_dim = q.shape[2] // 2
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
BLOCK_HEAD = 4
BLOCK_SEQ = 32
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
if head_dim >= 128:
num_warps = 8
else:
num_warps = 4
Llama2Forwards._rotary_kernel[grid](
q,
cos,
sin,
q.stride(0),
q.stride(1),
q.stride(2),
cos.stride(0),
cos.stride(1),
sin.stride(0),
sin.stride(1),
total_len,
head_num,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_SEQ=BLOCK_SEQ,
BLOCK_DMODEL=head_dim,
num_warps=num_warps,
num_stages=1,
)
return
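The _rotary_kernel above rotates interleaved even/odd channel pairs of q in place. A plain PyTorch sketch of the same transform, handy for eyeballing correctness (assuming q is [total_len, head_num, head_dim] and cos/sin are [total_len, head_dim // 2]):

import torch

def rotary_emb_reference(q, cos, sin):
    q0 = q[..., 0::2]    # even channels
    q1 = q[..., 1::2]    # odd channels
    c = cos[:, None, :]  # broadcast over heads
    s = sin[:, None, :]
    out = torch.empty_like(q)
    out[..., 0::2] = q0 * c - q1 * s
    out[..., 1::2] = q0 * s + q1 * c
    return out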
@@ -402,3 +402,440 @@ if HAS_TRITON:
        prob = None
        return
class Llama2TokenAttentionForwards:
@staticmethod
@triton.jit
def _fwd_kernel(
Logics,
V,
Out,
B_Loc,
B_Start_Loc,
B_Seqlen,
max_input_len,
stride_logic_h,
stride_logic_bs,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_b_loc_b,
stride_b_loc_s,
other_kv_index,  # fallback cache slot for masked loads, avoids NaN
kv_group_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s
v_ptrs = V + off_v
e_max = float("-inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
v_index = tl.load(
B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_seq_len,
other=other_kv_index,
)
qk = tl.load(
Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
mask=start_n + offs_n < cur_batch_seq_len,
other=float("-inf"),
)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
e_sum = e_sum * old_scale + tl.sum(p, 0)
v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
e_max = n_e_max
acc = acc / e_sum
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
return
@staticmethod
@torch.no_grad()
def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
grid = (batch, head)
kv_group_num = logics.shape[0] // v.shape[1]
num_warps = 1
Llama2TokenAttentionForwards._fwd_kernel[grid](
logics,
v,
o,
b_loc,
b_start_loc,
b_seq_len,
max_input_len,
logics.stride(0),
logics.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
b_loc.stride(0),
b_loc.stride(1),
other_kv_index,
kv_group_num,
BLOCK_DMODEL=v.shape[-1],
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=3,
)
return
@staticmethod
@triton.jit
def _fwd_kernel_token_softmax(
Logics,
B_Start_Loc,
B_Seqlen,
Prob_Out,
stride_logic_h,
stride_logic_bs,
stride_prob_h,
stride_prob_bs,
BLOCK_SIZE: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
col_offsets = tl.arange(0, BLOCK_SIZE)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
row = tl.load(
Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs,
mask=col_offsets < cur_batch_seq_len,
other=-float("inf"),
).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
tl.store(
Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs,
softmax_output,
mask=col_offsets < cur_batch_seq_len,
)
return
@staticmethod
@torch.no_grad()
def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len):
BLOCK_SIZE = triton.next_power_of_2(max_input_len)
batch, head_num = B_Start_Loc.shape[0], Logics.shape[0]
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)](
Logics,
B_Start_Loc,
B_Seqlen,
Prob_Out,
Logics.stride(0),
Logics.stride(1),
Prob_Out.stride(0),
Prob_Out.stride(1),
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
return
@staticmethod
@triton.jit
def _fwd_kernel_token_att1(
Q,
K,
sm_scale,
B_Loc,
B_Start_Loc,
B_Seqlen,
max_input_len,
Att_Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
att_stride_h,
att_stride_bs,
kv_group_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_n = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_start_index = max_input_len - cur_batch_seq_len
cur_batch_end_index = max_input_len
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_start_index = start_n * BLOCK_N
block_mask = tl.where(block_start_index < cur_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
return
@staticmethod
@torch.no_grad()
def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lk**0.5)
batch, head_num = B_Loc.shape[0], q.shape[1]
grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))
kv_group_num = q.shape[1] // k.shape[1]
num_warps = 4 if Lk <= 64 else 8
num_warps = 2
Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid](
q,
k,
sm_scale,
B_Loc,
B_Start_Loc,
B_Seqlen,
max_input_len,
att_out,
B_Loc.stride(0),
B_Loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
att_out.stride(0),
att_out.stride(1),
kv_group_num=kv_group_num,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
@staticmethod
@triton.jit
def _fwd_kernel_token_att2(
Prob,
V,
Out,
B_Loc,
B_Start_Loc,
B_Seqlen,
max_input_len, # B_Start_Loc cumsum of input lens if continuous
stride_b_loc_b,
stride_b_loc_s,
stride_ph,
stride_pbs,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
kv_group_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_kv_head = cur_head // kv_group_num
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_index = max_input_len - cur_batch_seq_len
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s
p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs
v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
p_value = tl.load(
Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0
)
v_loc = tl.load(
B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0
)
v_value = tl.load(
V + v_offs + v_loc[:, None] * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
acc += tl.sum(p_value[:, None] * v_value, 0)
acc = acc.to(tl.float16)
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
return
@staticmethod
@torch.no_grad()
def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):
if triton.__version__ >= "2.1.0":
BLOCK = 128
else:
BLOCK = 64
batch, head = B_Loc.shape[0], prob.shape[0]
grid = (batch, head)
num_warps = 4
dim = v.shape[-1]
kv_group_num = prob.shape[0] // v.shape[1]
Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid](
prob,
v,
out,
B_Loc,
B_Start_Loc,
B_Seqlen,
max_input_len,
B_Loc.stride(0),
B_Loc.stride(1),
prob.stride(0),
prob.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
out.stride(0),
out.stride(1),
out.stride(2),
kv_group_num=kv_group_num,
BLOCK_DMODEL=dim,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
# this is the interface of llama2 attn forward
@staticmethod
@torch.no_grad()
def token_attn(
q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index
):
total_token_num = k.shape[0]
batch_size, head_num, head_dim = q.shape
calcu_shape1 = (batch_size, head_num, head_dim)
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
Llama2TokenAttentionForwards.token_att_fwd(
q,
k,
att_m_tensor,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
)
if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor)
Llama2TokenAttentionForwards.token_softmax_fwd(
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
)
att_m_tensor = None
Llama2TokenAttentionForwards.token_att_fwd2(
prob,
v,
attn_out.view(calcu_shape1),
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
)
prob = None
return
elif triton.__version__ >= "2.1.0":
Llama2TokenAttentionForwards.token_softmax_reducev_fwd(
att_m_tensor,
v,
attn_out.view(calcu_shape1),
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
other_kv_index,
)
else:
raise Exception("unsupported triton version: 2.0.0 or newer is required")
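token_attn implements the decode step: each sequence contributes one new query token that attends over its cached keys and values, with kv_group_num query heads sharing each cached head. A naive single-sequence PyTorch reference, for comparison only (assumes q is [q_heads, head_dim] and k_cache/v_cache are [seq_len, kv_heads, head_dim] already gathered for that sequence):

import torch

def token_attn_reference(q, k_cache, v_cache):
    kv_group_num = q.shape[0] // k_cache.shape[1]
    ks = k_cache.repeat_interleave(kv_group_num, dim=1)            # [S, Hq, d]
    vs = v_cache.repeat_interleave(kv_group_num, dim=1)
    scores = torch.einsum("hd,shd->hs", q.float(), ks.float()) / (q.shape[-1] ** 0.5)
    probs = torch.softmax(scores, dim=-1)                          # [Hq, S]
    return torch.einsum("hs,shd->hd", probs, vs.float()).to(q.dtype)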
@@ -380,12 +380,10 @@ class SelfAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)
        self.projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads
        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
@@ -445,7 +443,6 @@ class SelfAttention(torch.nn.Module):
        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)
        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
@@ -541,7 +538,6 @@ class SelfAttention(torch.nn.Module):
        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)
        return output, kv_cache
...
@@ -164,6 +164,13 @@ _INFER_POLICY_LIST = {
    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
        file_name="bloom", class_name="BloomModelInferPolicy"
    ),
+    # ChatGLM2
+    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
+        file_name="chatglm2", class_name="ChatGLM2InferPolicy"
+    ),
+    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+        file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
+    ),
}
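With the two registrations above, get_autopolicy can resolve the inference policy for a loaded ChatGLM2 model from its fully qualified class name. A small illustration of how that registry key is formed (a sketch of the lookup, not the exact ColossalAI helper):

def policy_key(model) -> str:
    cls = model.__class__
    return f"{cls.__module__}.{cls.__qualname__}"

# A ChatGLMModel built from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm yields
# "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel",
# which _INFER_POLICY_LIST maps to ChatGLM2InferPolicy.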
@@ -208,7 +215,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
    if policy_location is None:
        raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
        )
    else:
        policy = import_policy(policy_location, inference_only)
...
@@ -39,6 +39,21 @@ config = ChatGLMConfig(
    padded_vocab_size=65024,
    hidden_size=64,
    num_attention_heads=8,
+    kv_channels=16,
+    rmsnorm=True,
+    original_rope=True,
+    use_cache=True,
+    torch_dtype=torch.float32,
+)
+
+infer_config = ChatGLMConfig(
+    num_layers=2,
+    padded_vocab_size=65024,
+    hidden_size=128,
+    num_attention_heads=8,
+    multi_query_attention=True,
+    multi_query_group_num=2,
+    kv_channels=16,
    rmsnorm=True,
    original_rope=True,
    use_cache=True,
...
import os
import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoTokenizer
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo.transformers.chatglm2 import infer_config
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 1
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@parameterize(
"test_config",
[
{
"tp_size": TPSIZE,
}
],
)
def run_chatglm2_test(test_config):
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
# pad_token_id = 0
model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False)
orig_model = model_fn()
orig_model = orig_model.half()
text = ["how is the weather today?"]
input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
)
infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
outputs = infer_engine.generate(input_ids, **generate_kwargs)
assert outputs is not None
# print("outputs.shape: ", outputs[0].shape)
# print("outputs: ", outputs[0])
if not dist.is_initialized() or dist.get_rank() == 0:
for o in outputs:
output_text = tokenizer.decode(o)
print(output_text)
def check_chatglm2(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_chatglm2_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm2():
spawn(check_chatglm2, TPSIZE)
if __name__ == "__main__":
test_chatglm2()
import pytest
import torch
from packaging import version
try:
    from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
xq = xq.view(bs, 1, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
prob = torch.softmax(logics, dim=1)
prob = prob.view(bs, seqlen, num_head, 1)
return torch.sum(prob * xv, dim=1, keepdim=False)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test():
Z, head_num, seq_len, head_dim = 2, 32, 2048, 128
dtype = torch.float16
# attn out: 2,4096
q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
o = torch.empty_like(q)
# o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
max_kv_cache_len = seq_len
kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
other_kv_index = 2048
kv_cache_seq_len[:] = seq_len
kv_cache_start_loc[0] = 0
kv_cache_start_loc[1] = seq_len
for i in range(Z):
kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
Llama2TokenAttentionForwards.token_attn(
q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index
)
torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
assert torch.allclose(torch_out, o, atol=1e-3, rtol=0)
if __name__ == "__main__":
test()