Unverified Commit 641b1ee7 authored by Hongxin Liu, committed by GitHub

[devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 341263df
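Almost everything below is mechanical reformatting produced by the pre-commit hooks: single quotes become double quotes, long call sites are split one argument per line, and imports are sorted. A minimal sketch of the kind of normalization a Black-style hook applies (illustrative only; the repository's actual hook configuration is not shown in this diff):

# Illustrative: reproduce the quote normalization seen throughout this diff.
# Requires `pip install black`; the 120-character line length is an assumption.
import black

src = "PATH = os.environ.get('TEST_DOCUMENT_LOADER_DATA_PATH')\n"
formatted = black.format_str(src, mode=black.Mode(line_length=120))
print(formatted)  # PATH = os.environ.get("TEST_DOCUMENT_LOADER_DATA_PATH")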
import os

from colossalqa.data_loader.document_loader import DocumentLoader


def test_add_document():
    PATH = os.environ.get("TEST_DOCUMENT_LOADER_DATA_PATH")
    files = [[PATH, "all data"]]
    document_loader = DocumentLoader(files)
    documents = document_loader.all_data
    all_files = []
    for doc in documents:
        assert isinstance(doc.page_content, str) == True
        if doc.metadata["source"] not in all_files:
            all_files.append(doc.metadata["source"])
    print(all_files)
    assert len(all_files) == 6


if __name__ == "__main__":
    test_add_document()
@@ -4,56 +4,44 @@ from colossalqa.retrieval_conversation_universal import UniversalRetrievalConver


def test_en_retrievalQA():
    data_path_en = os.environ.get("TEST_DATA_PATH_EN")
    data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
    en_model_path = os.environ.get("EN_MODEL_PATH")
    zh_model_path = os.environ.get("ZH_MODEL_PATH")
    zh_model_name = os.environ.get("ZH_MODEL_NAME")
    en_model_name = os.environ.get("EN_MODEL_NAME")
    sql_file_path = os.environ.get("SQL_FILE_PATH")
    qa_session = UniversalRetrievalConversation(
        files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
        files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
        zh_model_path=zh_model_path,
        en_model_path=en_model_path,
        zh_model_name=zh_model_name,
        en_model_name=en_model_name,
        sql_file_path=sql_file_path,
    )
    ans = qa_session.run("which company runs business in hotel industry?", which_language="en")
    print(ans)


def test_zh_retrievalQA():
    data_path_en = os.environ.get("TEST_DATA_PATH_EN")
    data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
    en_model_path = os.environ.get("EN_MODEL_PATH")
    zh_model_path = os.environ.get("ZH_MODEL_PATH")
    zh_model_name = os.environ.get("ZH_MODEL_NAME")
    en_model_name = os.environ.get("EN_MODEL_NAME")
    sql_file_path = os.environ.get("SQL_FILE_PATH")
    qa_session = UniversalRetrievalConversation(
        files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
        files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
        zh_model_path=zh_model_path,
        en_model_path=en_model_path,
        zh_model_name=zh_model_name,
        en_model_name=en_model_name,
        sql_file_path=sql_file_path,
    )
    ans = qa_session.run("哪家公司在经营酒店业务?", which_language="zh")
    print(ans)
...
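Both tests are configured entirely through environment variables, so they can point at any local corpus and model checkpoints. A hedged sketch of driving one of them by hand (every path and model name below is a placeholder, not a value from this repository):

# Placeholder wiring for test_en_retrievalQA; adjust the paths to local files.
import os

os.environ["TEST_DATA_PATH_EN"] = "/data/company_info_en.txt"  # placeholder
os.environ["TEST_DATA_PATH_ZH"] = "/data/company_info_zh.txt"  # placeholder
os.environ["EN_MODEL_PATH"] = "/models/en-model"  # placeholder
os.environ["ZH_MODEL_PATH"] = "/models/zh-model"  # placeholder
os.environ["EN_MODEL_NAME"] = "llama"  # placeholder
os.environ["ZH_MODEL_NAME"] = "chatglm"  # placeholder
os.environ["SQL_FILE_PATH"] = "/tmp/colossalqa.db"  # placeholder

test_en_retrievalQA()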
0.0.1
\ No newline at end of file
from . import accelerator
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch

try:
    # .version will be created by setup.py
...
@@ -27,7 +27,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -93,9 +93,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
        with FSDP.state_dict_type(
            model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        ):
            state_dict = model.unwrap().state_dict()
@@ -172,7 +170,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
        with FSDP.state_dict_type(
            optimizer.unwrap_model().unwrap(),
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            fsdp_optim_state = FSDP.full_optim_state_dict(
                optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
@@ -241,7 +239,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
            )
        optimizer.load_state_dict(fsdp_state)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save model to checkpoint but only on master process.
...
@@ -294,6 +294,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Helper functions for saving state dict
# ======================================


def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
    """
    Save state dict to checkpoint.
@@ -305,7 +306,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
    """
    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)

    if use_safetensors:
        assert is_safetensors_available(), "safetensors is not available."
        assert checkpoint_file_path.endswith(
...
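The `tree_map` call above is the whole CPU-offload trick: it walks an arbitrarily nested state dict and moves only the tensor leaves. A self-contained sketch of the same pattern (the file name and tensors are illustrative):

# Illustrative sketch of the CPU-offload-before-save pattern used above.
import torch
from torch.utils._pytree import tree_map

state_dict = {"layer.weight": torch.randn(4, 4), "step": 10}
# Tensor leaves are moved to CPU; non-tensor leaves (e.g. the int) pass through.
state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
torch.save(state_dict_cpu, "checkpoint.bin")  # or safetensors' save_file for tensors-only dicts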
@@ -174,16 +174,20 @@ class ProcessGroupMesh:
            List[Tuple[int, ...]]: Coordinates along the axis.
        """
        if isinstance(axis, int):
            axis = [
                axis,
            ]
            assert isinstance(indices_at_axis[0], int)
            indices_at_axis = [
                indices_at_axis,
            ]

        def add_index(base_coord, axis, indices_at_axis):
            coords_in_group = []
            for idx in indices_at_axis:
                coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
            return coords_in_group

        coords_in_group = [base_coord]
        for ax, indices_at_ax in zip(axis, indices_at_axis):
            new_coords_in_group = []
@@ -194,7 +198,10 @@ class ProcessGroupMesh:
        return coords_in_group

    def create_group_along_axis(
        self,
        axis: Union[int, List[int]],
        indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
        backend: Optional[str] = None,
    ) -> ProcessGroup:
        """Create all process groups along the given axis, and return the one which the current process belongs to.
@@ -207,11 +214,15 @@ class ProcessGroupMesh:
            ProcessGroup: The process group along the given axis which the current process belongs to.
        """
        if isinstance(axis, int):
            axis = [
                axis,
            ]
            if indices_at_axis is not None:
                assert isinstance(indices_at_axis[0], int)
                indices_at_axis = [
                    indices_at_axis,
                ]

        indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
        reduced_shape = list(self._shape)
        # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
...
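The `add_index` helper above expands a base coordinate along one axis at a time; chaining it over several axes enumerates the full Cartesian product of coordinates in a group. A dependency-free sketch of that expansion (toy mesh shape, no distributed setup needed):

# Toy re-run of the coordinate expansion from get_coords_along_axis above.
def add_index(base_coord, axis, indices_at_axis):
    return [base_coord[:axis] + (idx,) + base_coord[axis + 1 :] for idx in indices_at_axis]

coords = [(0, 0, 0)]  # base coordinate in a 3D mesh
for ax, indices in zip([1, 2], [[0, 1], [0, 1, 2]]):
    coords = [c for base in coords for c in add_index(base, ax, indices)]
print(coords)  # 2 * 3 = 6 coordinates varying along axes 1 and 2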
@@ -29,13 +29,17 @@ except:
try:
    from colossalai.kernel.triton.flash_decoding import token_flash_decoding

    HAS_TRITON_FLASH_DECODING_KERNEL = True
except:
    print(
        "no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
    )
    HAS_TRITON_FLASH_DECODING_KERNEL = False

try:
    from flash_attn import flash_attn_with_kvcache

    HAS_FLASH_KERNEL = True
except:
    HAS_FLASH_KERNEL = False
@@ -48,6 +52,7 @@ def rotate_half(x):
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -96,17 +101,22 @@ def llama_triton_context_attention(
        infer_state.max_len_in_batch,
    )


def llama_triton_token_attention(
    query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1
):
    if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
        token_flash_decoding(
            q=query_states,
            o_tensor=attn_output,
            infer_state=infer_state,
            q_head_num=q_head_num,
            head_dim=head_dim,
            cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
            cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
        )
        return

    if num_key_value_groups == 1:
        token_attention_fwd(
            query_states,
@@ -459,14 +469,15 @@ class LlamaInferenceForwards:
            )

            if HAS_LIGHTLLM_KERNEL:
                attn_output = torch.empty_like(query_states)
                llama_triton_token_attention(
                    query_states=query_states,
                    attn_output=attn_output,
                    infer_state=infer_state,
                    num_key_value_groups=self.num_key_value_groups,
                    q_head_num=q_len * self.num_heads,
                    head_dim=self.head_dim,
                )
            else:
                self.num_heads // self.num_key_value_heads
                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
...
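`rotate_half` above splits the last dimension in two and swaps the halves with a sign flip; combined with the cos/sin tables in `apply_rotary_pos_emb`, it implements rotary position embeddings. A quick eager check of the identity it satisfies (a sketch, not the repository's test):

# rotate_half applied twice negates the input: rotate(rotate(x)) == -x.
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

x = torch.randn(2, 8)
assert torch.allclose(rotate_half(rotate_half(x)), -x)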
@@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder

    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn("CUDA gptq is not installed")
    HAS_GPTQ_CUDA = False


class CaiQuantLinear(nn.Module):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__()
        if bits not in [2, 4, 8]:
@@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer(
            "qzeros",
            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
        )
        self.register_buffer(
            "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
        )
        if row_split:
            self.register_buffer(
                "g_idx",
                torch.tensor(
                    [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
                ),
            )
        else:
            self.register_buffer(
                "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
            )

        if bias:
            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None
@@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
        self.row_split = row_split

    def pack(self, linear, scales, zeros, g_idx=None):
        g_idx = (
            g_idx.clone()
            if g_idx is not None
            else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
        )

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
@@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        pbits = 32
        ptype = torch.int32
        unsign_type = np.uint32
@@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
                    :, None
                ]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(unsign_type)
@@ -109,7 +116,7 @@ class CaiQuantLinear(nn.Module):
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        qweight = qweight.astype(sign_type)
        qweight1 = torch.from_numpy(qweight)
        qweight1 = qweight1.contiguous()  # .to("cuda")
        self.qweight.data.copy_(qweight1)

        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
@@ -140,17 +147,20 @@ class CaiQuantLinear(nn.Module):
        self.q4_width = self.qweight.shape[1]
        if self.g_idx is not None:
            if self.row_split and torch.equal(
                self.g_idx,
                torch.tensor(
                    [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
                    dtype=torch.int32,
                    device=self.g_idx.device,
                ),
            ):
                self.g_idx = None
            elif torch.equal(
                self.g_idx,
                torch.tensor(
                    [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
                ),
            ):
                self.g_idx = None

        if self.g_idx is not None:
@@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
        outshape = x.shape[:-1] + (self.outfeatures,)

        if HAS_GPTQ_CUDA and self.bits == 4:
            if self.q4 is None:
                self.init_q4()
@@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
    qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
    qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
    scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
@@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
    zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num

    for i in range(split_num):
        cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
        ]
        cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
            :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
        ]
        cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
        ]
        if cai_linear.bias is not None:
            cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
                tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
            ]

    cai_linear.g_idx.copy_(g_idx)


def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
    qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
    qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
    scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
@@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
    idx_split_features = cai_linear.infeatures // split_num

    for i in range(split_num):
        cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
            tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
        ]
        cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
        ]
        cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
        ]
        cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
            tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
        ]
    if cai_linear.bias is not None:
        cai_linear.bias.copy_(gptq_linear.bias)


class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__(
            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
        )
        self.process_group = None

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
@@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = RowCaiQuantLinear(
            module.bits,
            module.group_size,
            module.in_features // tp_size,
            module.out_features,
            module.bias is not None,
            tp_size=tp_size,
            tp_rank=tp_rank,
            row_split=True,
        )
        linear_1d.process_group = process_group

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
@@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__(
            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
        )
        self.process_group = None

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
@@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = ColCaiQuantLinear(
            module.bits,
            module.group_size,
            module.in_features,
            module.out_features // tp_size,
            module.bias is not None,
            tp_size=tp_size,
            tp_rank=tp_rank,
        )
        linear_1d.process_group = process_group

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
...
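The buffer shapes registered in `__init__` encode the GPTQ packing: each 32-bit word stores `32 // bits` quantized values, and scales/zeros are kept once per group. A quick sanity check of those shape formulas with toy sizes (the numbers are illustrative):

# Toy check of CaiQuantLinear's packed-buffer shape arithmetic.
import math

bits, infeatures, outfeatures, groupsize = 4, 128, 256, 64
qweight_shape = (infeatures // 32 * bits, outfeatures)  # (16, 256): 8 int4 values per int32 row
qzeros_shape = (math.ceil(infeatures / groupsize), outfeatures // 32 * bits)  # (2, 32)
scales_shape = (math.ceil(infeatures / groupsize), outfeatures)  # (2, 256)
print(qweight_shape, qzeros_shape, scales_shape)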
@@ -5,6 +5,7 @@ import torch
try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
@@ -16,6 +17,7 @@ if HAS_TRITON:
    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
    """
    if triton.__version__ < "2.1.0":

        @triton.jit
        def _context_flash_attention_kernel(
            Q,
@@ -131,29 +133,47 @@ if HAS_TRITON:
            m_i = m_i_new

            off_o = (
                (cur_batch_start_index + offs_m[:, None]) * stride_obs
                + cur_head * stride_oh
                + offs_d[None, :] * stride_od
            )
            out_ptrs = Out + off_o
            tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
            return

    else:
        # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
        @triton.jit
        def _context_flash_attention_kernel_2(
            Q,
            K,
            V,
            sm_scale,
            Alibi,
            B_Start_Loc,
            B_Seqlen,
            Out,
            kv_group_num,
            stride_qbs,
            stride_qh,
            stride_qd,
            stride_kbs,
            stride_kh,
            stride_kd,
            stride_vbs,
            stride_vh,
            stride_vd,
            stride_obs,
            stride_oh,
            stride_od,
            BLOCK_M: tl.constexpr,
            BLOCK_DMODEL: tl.constexpr,
            BLOCK_N: tl.constexpr,
        ):
            cur_batch = tl.program_id(0)
            cur_head = tl.program_id(1)
            start_m = tl.program_id(2)

            if kv_group_num is not None:
                cur_kv_head = cur_head // kv_group_num
@@ -166,7 +186,11 @@ if HAS_TRITON:
            offs_n = tl.arange(0, BLOCK_N)
            offs_d = tl.arange(0, BLOCK_DMODEL)
            offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            off_q = (
                (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
                + cur_head * stride_qh
                + offs_d[None, :] * stride_qd
            )
            if kv_group_num is None or kv_group_num == 1:
                off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
                off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
@@ -191,8 +215,11 @@ if HAS_TRITON:
            for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
                start_n = tl.multiple_of(start_n, BLOCK_N)
                # -- compute qk ----
                k = tl.load(
                    k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
                    other=0.0,
                )

                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
                qk += tl.dot(q, k)
@@ -220,8 +247,11 @@ if HAS_TRITON:
                acc_scale = l_i / l_i_new * alpha
                acc = acc * acc_scale[:, None]
                # update acc
                v = tl.load(
                    v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
                    mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
                    other=0.0,
                )

                p = p.to(v.dtype)
                acc += tl.dot(p, v)
@@ -229,7 +259,11 @@ if HAS_TRITON:
                l_i = l_i_new
                m_i = m_i_new
            # initialize pointers to output
            off_o = (
                (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
                + cur_head * stride_oh
                + offs_d[None, :] * stride_od
            )
            out_ptrs = Out + off_o
            tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
            return
@@ -249,7 +283,7 @@ if HAS_TRITON:
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        num_warps = 4 if Lk <= 64 else 8

        if triton.__version__ < "2.1.0":
            tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
            _context_flash_attention_kernel[grid](
@@ -286,20 +320,26 @@ if HAS_TRITON:
            )
        else:
            _context_flash_attention_kernel_2[grid](
                q,
                k,
                v,
                sm_scale,
                alibi,
                b_start_loc,
                b_seq_len,
                o,
                None,
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
@@ -307,7 +347,7 @@ if HAS_TRITON:
                num_warps=num_warps,
                num_stages=1,
            )

        return

    @torch.no_grad()
@@ -327,7 +367,7 @@ if HAS_TRITON:
        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        # num_warps = 4

        if triton.__version__ < "2.1.0":
            _context_flash_attention_kernel[grid](
                q,
@@ -337,7 +377,7 @@ if HAS_TRITON:
                b_start_loc,
                b_seq_len,
                tmp,
                None,
                o,
                q.stride(0),
                q.stride(1),
@@ -362,32 +402,33 @@ if HAS_TRITON:
            )
        else:
            kv_group_num = q.shape[1] // k.shape[1]
            _context_flash_attention_kernel_2[grid](
                q,
                k,
                v,
                sm_scale,
                None,
                b_start_loc,
                b_seq_len,
                o,
                kv_group_num,
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
        return
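Both kernel variants above carry a running maximum `m_i` and normalizer `l_i` and rescale the accumulator by `acc_scale` each block, which is the standard online-softmax recurrence of flash attention. A dependency-light eager sketch of that recurrence for one query row (a simplification; the kernels additionally apply `sm_scale`, ALiBi, and masking):

# Online-softmax weighted sum over one query row, checked against torch.softmax.
import torch

def online_softmax_weighted_sum(scores, values):
    m = torch.tensor(float("-inf"))  # running max
    l = torch.tensor(0.0)            # running normalizer
    acc = torch.zeros(values.shape[-1])
    for s, v in zip(scores, values):
        m_new = torch.maximum(m, s)
        alpha = torch.exp(m - m_new)  # rescales the old accumulator
        p = torch.exp(s - m_new)
        l = l * alpha + p
        acc = acc * alpha + p * v
        m = m_new
    return acc / l

scores, values = torch.randn(5), torch.randn(5, 4)
assert torch.allclose(online_softmax_weighted_sum(scores, values), torch.softmax(scores, 0) @ values, atol=1e-6)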
# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch

try:
    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2

    HAS_LIGHTLLM_KERNEL = True
except:
    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
@@ -10,41 +12,36 @@ except:
if HAS_LIGHTLLM_KERNEL:

    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
        BLOCK_SEQ = 256
        batch_size = infer_state.batch_size
        max_len_in_batch = infer_state.max_len_in_batch

        calcu_shape1 = (batch_size, q_head_num, head_dim)

        if getattr(infer_state, "mid_o", None) is None:
            infer_state.mid_o = torch.empty(
                [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim],
                dtype=torch.float32,
                device="cuda",
            )
            infer_state.mid_o_logexpsum = torch.empty(
                [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
            )

        mid_o = infer_state.mid_o
        mid_o_logexpsum = infer_state.mid_o_logexpsum

        flash_decode_stage1(
            q.view(calcu_shape1),
            cache_k,
            cache_v,
            infer_state.block_loc,
            infer_state.seq_len,
            infer_state.max_len_in_batch,
            mid_o,
            mid_o_logexpsum,
            BLOCK_SEQ,
        )
        flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
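`flash_decode_stage1` writes one partial result per 256-token chunk (`mid_o`) together with its log-sum-exp (`mid_o_logexpsum`), and stage 2 merges them. A hedged eager sketch of that merge for a single head, assuming each chunk's partial output is already softmax-normalized within its chunk:

# Merge per-chunk flash-decoding partials; shapes are illustrative.
import torch

def merge_partials(mid_o: torch.Tensor, mid_lse: torch.Tensor) -> torch.Tensor:
    # mid_o: [num_chunks, head_dim], mid_lse: [num_chunks]
    weights = torch.softmax(mid_lse, dim=0)  # relative attention mass of each chunk
    return (weights[:, None] * mid_o).sum(dim=0)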
@@ -8,6 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
@@ -26,8 +27,8 @@ if HAS_TRITON:
        X_GATE2,
        X_UP,
        Y,
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        BLOCK_SIZE: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
@@ -41,9 +42,9 @@ if HAS_TRITON:
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
            x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
            x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
            y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
            # Write output
@@ -58,8 +59,8 @@ if HAS_TRITON:
        X_GATE2_GRAD,
        X_UP_GRAD,
        Y_GRAD,
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        BLOCK_SIZE: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
@@ -76,10 +77,10 @@ if HAS_TRITON:
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
            x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
            y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0)

            # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
            x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
@@ -147,14 +148,9 @@ if HAS_TRITON:
        # restore setting
        ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
        # enqueue kernel
        _llama_act_combine_forward[(M,)](
            x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
        )
        return y

    @staticmethod
@@ -166,20 +162,25 @@ if HAS_TRITON:
        # init grad
        y_grad = grad_outputs[0]
        x_gate1_grad, x_gate2_grad, x_up_grad = (
            torch.empty_like(x_gate1),
            torch.empty_like(x_gate2),
            torch.empty_like(x_up),
        )

        # enqueue kernel
        _llama_act_combine_backward[(M,)](
            x_gate1,
            x_gate2,
            x_up,
            x_gate1_grad,
            x_gate2_grad,
            x_up_grad,
            y_grad,
            x_up.stride(-2),
            N,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
        return x_gate_grad, x_up_grad, None, None
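For reference, the math fused by the two kernels above is just the gated activation noted in the backward-pass comment. A plain PyTorch version, useful as a correctness oracle when testing the Triton path:

# Eager reference for y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up.
import torch

def llama_act_combine_ref(x_gate1, x_gate2, x_up):
    return x_gate1 * x_gate2 * torch.sigmoid(x_gate2) * x_up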
@@ -13,10 +13,18 @@ except ImportError:
    print("please install triton from https://github.com/openai/triton")

try:
    from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_bloom_token_att_fwd,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_llama_token_att_fwd,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
        token_att_fwd2 as lightllm_llama_token_att_fwd2,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
        token_softmax_fwd as lightllm_llama_token_softmax_fwd,
    )

    HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:
@@ -205,9 +213,7 @@ class Llama2TokenAttentionForwards:
        if triton.__version__ == "2.0.0":
            prob = torch.empty_like(att_m_tensor)
            lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
            att_m_tensor = None

            lightllm_llama_token_att_fwd2(
...
@@ -8,7 +8,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

from ._utils import copy_kv_to_mem_cache

try:
    from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
        context_attention_fwd as lightllm_llama_context_attention_fwd,
@@ -90,7 +92,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
            # infer_state.cache_manager.past_key_values_length,
            infer_state.max_len_in_batch,
        )

    else:
        Llama2TokenAttentionForwards.token_attn(
            query_states,
...
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
...
@@ -2,13 +2,13 @@ from .api import (
    compute_global_numel,
    customized_distributed_tensor_to_param,
    distribute_tensor,
    distribute_tensor_with_customization,
    get_device_mesh,
    get_global_shape,
    get_layout,
    get_sharding_spec,
    init_as_dtensor,
    init_tensor_as_customization_distributed,
    is_customized_distributed_tensor,
    is_distributed_tensor,
    is_sharded,
...
@@ -128,7 +128,10 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
    return sharded_tensor


def init_as_dtensor(
    tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
    dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
@@ -140,6 +143,7 @@ def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec
    return tensor


def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
    """
    Convert the layout of the tensor from source_spec to target_spec.
@@ -468,7 +472,6 @@ def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gat
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    # set the shard_fn and gather_fn as attributes of the distributed tensor
    tensor.shard_fn = shard_fn
    tensor.gather_fn = gather_fn
...
@@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list):
        total_norm += norm**2.0
    return math.sqrt(total_norm)


def sync_tensor(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list. When
...
@@ -220,7 +220,7 @@ model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(
)
```

## Training ViT with Hybrid Parallelism

Finally, we can train the model with the hybrid parallel strategy. We first define a training function that describes the training procedure. Note that when pipeline parallelism is used, `booster.execute_pipeline` must be called to run the training; it invokes the `scheduler` to manage the model's forward and backward passes.

```python
def run_forward_backward(
    model: nn.Module,
...
@@ -119,9 +119,7 @@ def main():
        if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
            # run pipeline forward backward
            batch = iter([batch])
            outputs = booster.execute_pipeline(batch, model, criterion, optimizer, return_loss=True)
        else:
            outputs = model(**batch)
            loss = criterion(outputs, None)
...