Unverified Commit 641b1ee7 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 341263df
import os
from colossalqa.data_loader.document_loader import DocumentLoader
def test_add_document():
PATH = os.environ.get('TEST_DOCUMENT_LOADER_DATA_PATH')
files = [[PATH, 'all data']]
PATH = os.environ.get("TEST_DOCUMENT_LOADER_DATA_PATH")
files = [[PATH, "all data"]]
document_loader = DocumentLoader(files)
documents = document_loader.all_data
all_files = []
for doc in documents:
assert isinstance(doc.page_content, str)==True
if doc.metadata['source'] not in all_files:
all_files.append(doc.metadata['source'])
assert isinstance(doc.page_content, str) == True
if doc.metadata["source"] not in all_files:
all_files.append(doc.metadata["source"])
print(all_files)
assert len(all_files) == 6
if __name__=='__main__':
if __name__ == "__main__":
test_add_document()
......@@ -4,56 +4,44 @@ from colossalqa.retrieval_conversation_universal import UniversalRetrievalConver
def test_en_retrievalQA():
data_path_en = os.environ.get('TEST_DATA_PATH_EN')
data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
en_model_path = os.environ.get('EN_MODEL_PATH')
zh_model_path = os.environ.get('ZH_MODEL_PATH')
zh_model_name = os.environ.get('ZH_MODEL_NAME')
en_model_name = os.environ.get('EN_MODEL_NAME')
sql_file_path = os.environ.get('SQL_FILE_PATH')
qa_session = UniversalRetrievalConversation(files_en=[{
'data_path': data_path_en,
'name': 'company information',
'separator': '\n'
}],
files_zh=[{
'data_path': data_path_zh,
'name': 'company information',
'separator': '\n'
}],
zh_model_path=zh_model_path,
en_model_path=en_model_path,
zh_model_name=zh_model_name,
en_model_name=en_model_name,
sql_file_path=sql_file_path)
ans = qa_session.run("which company runs business in hotel industry?", which_language='en')
data_path_en = os.environ.get("TEST_DATA_PATH_EN")
data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
en_model_path = os.environ.get("EN_MODEL_PATH")
zh_model_path = os.environ.get("ZH_MODEL_PATH")
zh_model_name = os.environ.get("ZH_MODEL_NAME")
en_model_name = os.environ.get("EN_MODEL_NAME")
sql_file_path = os.environ.get("SQL_FILE_PATH")
qa_session = UniversalRetrievalConversation(
files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
zh_model_path=zh_model_path,
en_model_path=en_model_path,
zh_model_name=zh_model_name,
en_model_name=en_model_name,
sql_file_path=sql_file_path,
)
ans = qa_session.run("which company runs business in hotel industry?", which_language="en")
print(ans)
def test_zh_retrievalQA():
data_path_en = os.environ.get('TEST_DATA_PATH_EN')
data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
en_model_path = os.environ.get('EN_MODEL_PATH')
zh_model_path = os.environ.get('ZH_MODEL_PATH')
zh_model_name = os.environ.get('ZH_MODEL_NAME')
en_model_name = os.environ.get('EN_MODEL_NAME')
sql_file_path = os.environ.get('SQL_FILE_PATH')
qa_session = UniversalRetrievalConversation(files_en=[{
'data_path': data_path_en,
'name': 'company information',
'separator': '\n'
}],
files_zh=[{
'data_path': data_path_zh,
'name': 'company information',
'separator': '\n'
}],
zh_model_path=zh_model_path,
en_model_path=en_model_path,
zh_model_name=zh_model_name,
en_model_name=en_model_name,
sql_file_path=sql_file_path)
ans = qa_session.run("哪家公司在经营酒店业务?", which_language='zh')
data_path_en = os.environ.get("TEST_DATA_PATH_EN")
data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
en_model_path = os.environ.get("EN_MODEL_PATH")
zh_model_path = os.environ.get("ZH_MODEL_PATH")
zh_model_name = os.environ.get("ZH_MODEL_NAME")
en_model_name = os.environ.get("EN_MODEL_NAME")
sql_file_path = os.environ.get("SQL_FILE_PATH")
qa_session = UniversalRetrievalConversation(
files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
zh_model_path=zh_model_path,
en_model_path=en_model_path,
zh_model_name=zh_model_name,
en_model_name=en_model_name,
sql_file_path=sql_file_path,
)
ans = qa_session.run("哪家公司在经营酒店业务?", which_language="zh")
print(ans)
......
0.0.1
\ No newline at end of file
0.0.1
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
from . import accelerator
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
try:
# .version will be created by setup.py
......
......@@ -27,7 +27,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
......@@ -93,9 +93,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
with FSDP.state_dict_type(
model.unwrap(),
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
):
state_dict = model.unwrap().state_dict()
......@@ -172,7 +170,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
with FSDP.state_dict_type(
optimizer.unwrap_model().unwrap(),
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
fsdp_optim_state = FSDP.full_optim_state_dict(
optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
......@@ -241,7 +239,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
)
optimizer.load_state_dict(fsdp_state)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
......
......@@ -294,6 +294,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Helper functions for saving state dict
# ======================================
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
"""
Save state dict to checkpoint.
......@@ -305,7 +306,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
"""
# Move all tensors in the state_dict to CPU before saving to avoid serialization issues
state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith(
......
......@@ -174,16 +174,20 @@ class ProcessGroupMesh:
List[Tuple[int, ...]]: Coordinates along the axis.
"""
if isinstance(axis, int):
axis = [axis,]
axis = [
axis,
]
assert isinstance(indices_at_axis[0], int)
indices_at_axis = [indices_at_axis,]
indices_at_axis = [
indices_at_axis,
]
def add_index(base_coord, axis, indices_at_axis):
coords_in_group = []
for idx in indices_at_axis:
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
return coords_in_group
coords_in_group = [base_coord]
for ax, indices_at_ax in zip(axis, indices_at_axis):
new_coords_in_group = []
......@@ -194,7 +198,10 @@ class ProcessGroupMesh:
return coords_in_group
def create_group_along_axis(
self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None
self,
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
......@@ -207,11 +214,15 @@ class ProcessGroupMesh:
ProcessGroup: The process group along the given axis which the current process belongs to.
"""
if isinstance(axis, int):
axis = [axis,]
axis = [
axis,
]
if indices_at_axis is not None:
assert isinstance(indices_at_axis[0], int)
indices_at_axis = [indices_at_axis,]
indices_at_axis = [
indices_at_axis,
]
indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
reduced_shape = list(self._shape)
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
......
......@@ -29,13 +29,17 @@ except:
try:
from colossalai.kernel.triton.flash_decoding import token_flash_decoding
HAS_TRITON_FLASH_DECODING_KERNEL = True
except:
print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
print(
"no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
)
HAS_TRITON_FLASH_DECODING_KERNEL = False
try:
from flash_attn import flash_attn_with_kvcache
HAS_FLASH_KERNEL = True
except:
HAS_FLASH_KERNEL = False
......@@ -48,6 +52,7 @@ def rotate_half(x):
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
......@@ -96,17 +101,22 @@ def llama_triton_context_attention(
infer_state.max_len_in_batch,
)
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
def llama_triton_token_attention(
query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1
):
if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
token_flash_decoding(q = query_states,
o_tensor = attn_output,
infer_state = infer_state,
q_head_num = q_head_num,
head_dim = head_dim,
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
return
token_flash_decoding(
q=query_states,
o_tensor=attn_output,
infer_state=infer_state,
q_head_num=q_head_num,
head_dim=head_dim,
cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
)
return
if num_key_value_groups == 1:
token_attention_fwd(
query_states,
......@@ -459,14 +469,15 @@ class LlamaInferenceForwards:
)
if HAS_LIGHTLLM_KERNEL:
attn_output = torch.empty_like(query_states)
llama_triton_token_attention(query_states = query_states,
attn_output = attn_output,
infer_state = infer_state,
num_key_value_groups = self.num_key_value_groups,
q_head_num = q_len * self.num_heads,
head_dim = self.head_dim)
llama_triton_token_attention(
query_states=query_states,
attn_output=attn_output,
infer_state=infer_state,
num_key_value_groups=self.num_key_value_groups,
q_head_num=q_len * self.num_heads,
head_dim=self.head_dim,
)
else:
self.num_heads // self.num_key_value_heads
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
......
......@@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False
class CaiQuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__()
if bits not in [2, 4, 8]:
......@@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
self.maxq = 2**self.bits - 1
self.groupsize = groupsize if groupsize != -1 else infeatures
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
self.register_buffer(
"qzeros",
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
)
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
self.register_buffer('scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
"scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
)
if row_split:
self.register_buffer(
'g_idx',
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
dtype=torch.int32))
"g_idx",
torch.tensor(
[(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
),
)
else:
self.register_buffer('g_idx',
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
self.register_buffer(
"g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
)
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
......@@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
self.row_split = row_split
def pack(self, linear, scales, zeros, g_idx=None):
g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
g_idx = (
g_idx.clone()
if g_idx is not None
else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
......@@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
if linear.bias is not None:
self.bias = linear.bias.clone().half()
wn = 8
pbits = 32
ptype = torch.int32
unsign_type = np.uint32
......@@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
None])
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
:, None
]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(unsign_type)
......@@ -109,7 +116,7 @@ class CaiQuantLinear(nn.Module):
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(sign_type)
qweight1 = torch.from_numpy(qweight)
qweight1 = qweight1.contiguous() #.to("cuda")
qweight1 = qweight1.contiguous() # .to("cuda")
self.qweight.data.copy_(qweight1)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
......@@ -140,17 +147,20 @@ class CaiQuantLinear(nn.Module):
self.q4_width = self.qweight.shape[1]
if self.g_idx is not None:
if self.row_split and torch.equal(
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device,
),
):
self.g_idx = None
elif torch.equal(
self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx,
torch.tensor(
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
),
):
self.g_idx = None
if self.g_idx is not None:
......@@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
outshape = x.shape[:-1] + (self.outfeatures,)
if HAS_GPTQ_CUDA and self.bits == 4:
if self.q4 is None:
self.init_q4()
......@@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
......@@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
for i in range(split_num):
cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
]
cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
:, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
]
cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
]
if cai_linear.bias is not None:
cai_linear.bias[i * cai_split_out_features:(i + 1) *
cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
]
cai_linear.g_idx.copy_(g_idx)
def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
......@@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
idx_split_features = cai_linear.infeatures // split_num
for i in range(split_num):
cai_linear.qweight[i * cai_split_in_features:(i + 1) *
cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
cai_split_in_features, :]
cai_linear.qzeros[i * zero_split_block:(i + 1) *
zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.scales[i * zero_split_block:(i + 1) *
zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.g_idx[i * idx_split_features:(i + 1) *
idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
idx_split_features]
cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
]
cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
]
cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
]
cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
]
if cai_linear.bias is not None:
cai_linear.bias.copy_(gptq_linear.bias)
class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__(bits,
groupsize,
infeatures,
outfeatures,
bias,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=row_split)
super().__init__(
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
)
self.process_group = None
@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
......@@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = RowCaiQuantLinear(module.bits,
module.group_size,
module.in_features // tp_size,
module.out_features,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True)
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = RowCaiQuantLinear(
module.bits,
module.group_size,
module.in_features // tp_size,
module.out_features,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True,
)
linear_1d.process_group = process_group
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
......@@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__(bits,
groupsize,
infeatures,
outfeatures,
bias,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=row_split)
super().__init__(
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
)
self.process_group = None
@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
......@@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = ColCaiQuantLinear(module.bits,
module.group_size,
module.in_features,
module.out_features // tp_size,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank)
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
linear_1d = ColCaiQuantLinear(
module.bits,
module.group_size,
module.in_features,
module.out_features // tp_size,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
)
linear_1d.process_group = process_group
split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
......
......@@ -5,6 +5,7 @@ import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
......@@ -16,6 +17,7 @@ if HAS_TRITON:
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
"""
if triton.__version__ < "2.1.0":
@triton.jit
def _context_flash_attention_kernel(
Q,
......@@ -131,29 +133,47 @@ if HAS_TRITON:
m_i = m_i_new
off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
(cur_batch_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
else:
# this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
@triton.jit
def _context_flash_attention_kernel_2(
Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
Out,
kv_group_num,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
Q,
K,
V,
sm_scale,
Alibi,
B_Start_Loc,
B_Seqlen,
Out,
kv_group_num,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
if kv_group_num is not None:
cur_kv_head = cur_head // kv_group_num
......@@ -166,7 +186,11 @@ if HAS_TRITON:
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
if kv_group_num is None or kv_group_num == 1:
off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
......@@ -191,8 +215,11 @@ if HAS_TRITON:
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
......@@ -220,8 +247,11 @@ if HAS_TRITON:
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
......@@ -229,7 +259,11 @@ if HAS_TRITON:
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
......@@ -249,7 +283,7 @@ if HAS_TRITON:
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
if triton.__version__ < "2.1.0":
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid](
......@@ -286,20 +320,26 @@ if HAS_TRITON:
)
else:
_context_flash_attention_kernel_2[grid](
q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
q,
k,
v,
sm_scale,
alibi,
b_start_loc,
b_seq_len,
o,
None,
q.stride(0),
q.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
......@@ -307,7 +347,7 @@ if HAS_TRITON:
num_warps=num_warps,
num_stages=1,
)
return
@torch.no_grad()
......@@ -327,7 +367,7 @@ if HAS_TRITON:
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
# num_warps = 4
if triton.__version__ < "2.1.0":
_context_flash_attention_kernel[grid](
q,
......@@ -337,7 +377,7 @@ if HAS_TRITON:
b_start_loc,
b_seq_len,
tmp,
None,
None,
o,
q.stride(0),
q.stride(1),
......@@ -362,32 +402,33 @@ if HAS_TRITON:
)
else:
kv_group_num = q.shape[1] // k.shape[1]
_context_flash_attention_kernel_2[grid](
q,
k,
v,
sm_scale,
_context_flash_attention_kernel_2[grid](
q,
k,
v,
sm_scale,
None,
b_start_loc,
b_start_loc,
b_seq_len,
o,
kv_group_num,
q.stride(0),
q.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,)
return
\ No newline at end of file
num_stages=1,
)
return
# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch
try:
from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
HAS_LIGHTLLM_KERNEL = True
except:
print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
......@@ -10,41 +12,36 @@ except:
if HAS_LIGHTLLM_KERNEL:
def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
BLOCK_SEQ = 256
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch
calcu_shape1 = (batch_size, q_head_num, head_dim)
if getattr(infer_state, 'mid_o', None) is None:
infer_state.mid_o = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1,
head_dim],
dtype=torch.float32,
device="cuda")
infer_state.mid_o_logexpsum = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1],
dtype=torch.float32,
device="cuda")
if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim],
dtype=torch.float32,
device="cuda",
)
infer_state.mid_o_logexpsum = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)
mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum
flash_decode_stage1(q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.block_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ)
flash_decode_stage2(mid_o,
mid_o_logexpsum,
infer_state.seq_len,
o_tensor.view(calcu_shape1),
BLOCK_SEQ)
flash_decode_stage1(
q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.block_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ,
)
flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
......@@ -8,6 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
......@@ -26,8 +27,8 @@ if HAS_TRITON:
X_GATE2,
X_UP,
Y,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
......@@ -41,9 +42,9 @@ if HAS_TRITON:
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
# Write output
......@@ -58,8 +59,8 @@ if HAS_TRITON:
X_GATE2_GRAD,
X_UP_GRAD,
Y_GRAD,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
......@@ -76,10 +77,10 @@ if HAS_TRITON:
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0)
# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
......@@ -147,14 +148,9 @@ if HAS_TRITON:
# restore setting
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
# enqueue kernel
_llama_act_combine_forward[(M,)](x_gate1,
x_gate2,
x_up,
y,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
_llama_act_combine_forward[(M,)](
x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
return y
@staticmethod
......@@ -166,20 +162,25 @@ if HAS_TRITON:
# init grad
y_grad = grad_outputs[0]
x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
x_gate2), torch.empty_like(x_up)
x_gate1_grad, x_gate2_grad, x_up_grad = (
torch.empty_like(x_gate1),
torch.empty_like(x_gate2),
torch.empty_like(x_up),
)
# enqueue kernel
_llama_act_combine_backward[(M,)](x_gate1,
x_gate2,
x_up,
x_gate1_grad,
x_gate2_grad,
x_up_grad,
y_grad,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
_llama_act_combine_backward[(M,)](
x_gate1,
x_gate2,
x_up,
x_gate1_grad,
x_gate2_grad,
x_up_grad,
y_grad,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
return x_gate_grad, x_up_grad, None, None
......@@ -13,10 +13,18 @@ except ImportError:
print("please install triton from https://github.com/openai/triton")
try:
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_bloom_token_att_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_llama_token_att_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
token_att_fwd2 as lightllm_llama_token_att_fwd2,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama_token_softmax_fwd,
)
HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:
......@@ -205,9 +213,7 @@ class Llama2TokenAttentionForwards:
if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor)
lightllm_llama_token_softmax_fwd(
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
)
lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None
lightllm_llama_token_att_fwd2(
......
......@@ -8,7 +8,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from ._utils import copy_kv_to_mem_cache
try:
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama_context_attention_fwd,
......@@ -90,7 +92,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
Llama2TokenAttentionForwards.token_attn(
query_states,
......
from .attn import AttnMaskType, ColoAttention
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
......
......@@ -2,13 +2,13 @@ from .api import (
compute_global_numel,
customized_distributed_tensor_to_param,
distribute_tensor,
init_as_dtensor,
distribute_tensor_with_customization,
init_tensor_as_customization_distributed,
get_device_mesh,
get_global_shape,
get_layout,
get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor,
is_distributed_tensor,
is_sharded,
......
......@@ -128,7 +128,10 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
return sharded_tensor
def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor:
def init_as_dtensor(
tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
......@@ -140,6 +143,7 @@ def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec
return tensor
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
"""
Convert the layout of the tensor from source_spec to target_spec.
......@@ -468,7 +472,6 @@ def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gat
assert callable(gather_fn), "The gather_fn must be callable."
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
# set the shard_fn and gather_fn as attributes of the distributed tensor
tensor.shard_fn = shard_fn
tensor.gather_fn = gather_fn
......
......@@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list):
total_norm += norm**2.0
return math.sqrt(total_norm)
def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When
......
......@@ -220,7 +220,7 @@ model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(
)
```
## 使用混合并行训练 ViT
最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。
最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。
```python
def run_forward_backward(
model: nn.Module,
......
......@@ -119,9 +119,7 @@ def main():
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
# run pipeline forward backward
batch = iter([batch])
outputs = booster.execute_pipeline(
batch, model, criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(batch, model, criterion, optimizer, return_loss=True)
else:
outputs = model(**batch)
loss = criterion(outputs, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment