Commit 6a583c2f authored by chenych

update dtk to 24.04.1 and modify README

parent 7d576a9a
@@ -52,9 +52,6 @@ class CacheEngine:
         # Initialize the cache.
         self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
-        if self.model_config.hf_config.model_type == 'yuan':
-            self.lf_gpu_cache = self._allocate_lf_cache(self.cache_config.max_num_seqs, "cuda")
-            self.lf_cpu_cache = self._allocate_lf_cache(self.cache_config.max_num_seqs, "cpu")

     def _allocate_kv_cache(
         self,
......
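With the allocation above removed from CacheEngine, the Yuan LF caches are no longer owned by the cache engine; they are carried per sequence group instead (see the lf1_caches/lf2_caches fields added below). A minimal sketch of the per-sequence layout this diff works with, using placeholder values for num_layers, hidden_size, dtype and device (the real ones come from ModelConfig later in the diff):

    import torch

    # Placeholder values; ModelRunner derives the real ones from the model config.
    num_layers, hidden_size = 2, 8192
    dtype, device = torch.float32, "cpu"  # the worker uses the model dtype and a CUDA device

    # One (1, hidden_size, 1, 1) lf1 entry and one (1, hidden_size // 2, 1, 1) lf2 entry
    # per layer for a single sequence, matching the shapes built in profile_run().
    lf1_caches = [[torch.zeros((1, hidden_size, 1, 1), dtype=dtype, device=device)]
                  for _ in range(num_layers)]
    lf2_caches = [[torch.zeros((1, hidden_size // 2, 1, 1), dtype=dtype, device=device)]
                  for _ in range(num_layers)]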
@@ -51,21 +51,41 @@ class PreparePromptMetadata(NamedTuple):
     lora_requests: Set[LoRARequest]
     multi_modal_input: Optional[torch.Tensor]
     slot_mapping: List[int]
+    lf1_caches: List[List[torch.Tensor]]
+    lf2_caches: List[List[torch.Tensor]]

     @classmethod
-    def empty(cls):
-        return PreparePromptMetadata(
-            input_tokens=[],
-            input_positions=[],
-            attn_metadata=None,
-            prompt_lens=[],
-            subquery_lens=[],
-            lora_index_mapping=[],
-            lora_prompt_mapping=[],
-            lora_requests=set(),
-            multi_modal_input=None,
-            slot_mapping=[],
-        )
+    def empty(cls, lf1_caches=None, lf2_caches=None):
+        if lf1_caches == None and lf2_caches == None:
+            return PreparePromptMetadata(
+                input_tokens=[],
+                input_positions=[],
+                attn_metadata=None,
+                prompt_lens=[],
+                subquery_lens=[],
+                lora_index_mapping=[],
+                lora_prompt_mapping=[],
+                lora_requests=set(),
+                multi_modal_input=None,
+                slot_mapping=[],
+                lf1_caches=[[]],
+                lf2_caches=[[]]
+            )
+        else:
+            return PreparePromptMetadata(
+                input_tokens=[],
+                input_positions=[],
+                attn_metadata=None,
+                prompt_lens=[],
+                subquery_lens=[],
+                lora_index_mapping=[],
+                lora_prompt_mapping=[],
+                lora_requests=set(),
+                multi_modal_input=None,
+                slot_mapping=[],
+                lf1_caches=lf1_caches,
+                lf2_caches=lf2_caches,
+            )

 class PrepareDecodeMetadata(NamedTuple):
@@ -76,18 +96,35 @@ class PrepareDecodeMetadata(NamedTuple):
     lora_prompt_mapping: List[int]
     lora_requests: Set[LoRARequest]
     slot_mapping: List[int]
+    lf1_caches: List[List[torch.Tensor]]
+    lf2_caches: List[List[torch.Tensor]]

     @classmethod
-    def empty(cls):
-        return PrepareDecodeMetadata(
-            input_tokens=[],
-            input_positions=[],
-            attn_metadata=None,
-            lora_index_mapping=[],
-            lora_prompt_mapping=[],
-            lora_requests=set(),
-            slot_mapping=[],
-        )
+    def empty(cls, lf1_caches=None, lf2_caches=None):
+        if lf1_caches == None or lf2_caches == None:
+            return PrepareDecodeMetadata(
+                input_tokens=[],
+                input_positions=[],
+                attn_metadata=None,
+                lora_index_mapping=[],
+                lora_prompt_mapping=[],
+                lora_requests=set(),
+                slot_mapping=[],
+                lf1_caches=[[]],
+                lf2_caches=[[]],
+            )
+        else:
+            return PrepareDecodeMetadata(
+                input_tokens=[],
+                input_positions=[],
+                attn_metadata=None,
+                lora_index_mapping=[],
+                lora_prompt_mapping=[],
+                lora_requests=set(),
+                slot_mapping=[],
+                lf1_caches=lf1_caches,
+                lf2_caches=lf2_caches,
+            )

 # How batches are constructed.
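A quick illustration of how the two empty() variants above behave; a sketch that assumes PreparePromptMetadata from this diff is importable and uses a placeholder num_layers:

    num_layers = 2  # placeholder

    # Non-Yuan path: empty() keeps its old call shape and fills placeholder LF fields.
    meta = PreparePromptMetadata.empty()
    assert meta.lf1_caches == [[]] and meta.lf2_caches == [[]]

    # Yuan path: the caller's per-layer cache lists are passed through unchanged.
    lf1 = [[] for _ in range(num_layers)]
    lf2 = [[] for _ in range(num_layers)]
    meta = PreparePromptMetadata.empty(lf1_caches=lf1, lf2_caches=lf2)
    assert meta.lf1_caches is lf1 and meta.lf2_caches is lf2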
@@ -152,6 +189,10 @@ class ModelRunner:
         self.attn_backend = get_attn_backend(
             self.model_config.dtype if model_config is not None else None)
+        self.num_layers = model_config.get_num_layers(parallel_config)
+        self.total_num_heads = model_config.hf_config.num_attention_heads
+        self.head_size = model_config.get_head_size()
+        self.hidden_size = self.head_size * self.total_num_heads

     def load_model(self) -> None:
         with CudaMemoryProfiler() as m:
@@ -232,9 +273,13 @@ class ModelRunner:
         subquery_lens: List[int] = []
         prefix_block_tables: List[List[int]] = []
         multi_modal_input_list: List[torch.Tensor] = []
+        lf1_caches: List[List[torch.Tensor]] = [[] for i in range(self.num_layers)]
+        lf2_caches: List[List[torch.Tensor]] = [[] for i in range(self.num_layers)]

         if len(seq_group_metadata_list) == 0:
-            return PreparePromptMetadata.empty()
+            if self.model_config.hf_config.model_type == 'yuan':
+                return PreparePromptMetadata.empty(lf1_caches=lf1_caches, lf2_caches=lf2_caches)
+            else:
+                return PreparePromptMetadata.empty()

         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -263,6 +308,8 @@ class ModelRunner:
             prompt_lens.append(prompt_len)

             # NOTE: This only works for oooooooxxx style attention.
+            # zhaoxd: only seq_groups in the RUNNING stage have computed_block_nums != None;
+            # zhaoxd: otherwise prefix_block_tables only gets an empty list
             if computed_block_nums is not None and len(
                     computed_block_nums) > 0 and self.sliding_window is None:
                 # Prefix is not supported with sliding_window
@@ -306,13 +353,20 @@ class ModelRunner:
                 multi_modal_input_list.append(
                     seq_group_metadata.multi_modal_data.data)

+            if self.model_config.hf_config.model_type == 'yuan':
+                for i in range(self.num_layers):
+                    lf1_caches[i].extend(seq_group_metadata.lf1_caches[i])
+                    lf2_caches[i].extend(seq_group_metadata.lf2_caches[i])

             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
+                # zhaoxd: during profiling the block tables are not initialized, so a dummy slot mapping is enough
                 slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
                 continue

             # Compute the slot mapping.
+            # zhaoxd: block_table is the block table of the current seq group
             block_table = seq_group_metadata.block_tables[seq_id]

             # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
             # where start_idx is max(0, prompt_len - sliding_window).
@@ -325,12 +379,12 @@ class ModelRunner:
                         "Prefix caching is currently not supported with "
                         "sliding window attention")
                 start_idx = max(0, prompt_len - self.sliding_window)
+            # zhaoxd: normally computed_len=0, prefill_end=len(prompt), start_idx=0 (if sliding_window is enabled, start from the sliding_window position)
             for i in range(computed_len, prefill_end):
                 if i < start_idx:
                     slot_mapping.append(_PAD_SLOT_ID)
                     continue
+                # zhaoxd: the slot is the offset within this block_table
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
@@ -399,7 +453,7 @@ class ModelRunner:
             subquery_start_loc=subquery_start_loc,
             seq_start_loc=seq_start_loc,
             context_lens=context_lens_tensor,
-            block_tables=block_tables,
+            block_tables=block_tables,  # zhaoxd: the prefix block_tables, unrelated to the block_tables allocated to the seq_group
             use_cuda_graph=False,
         )
         return PreparePromptMetadata(
@@ -413,6 +467,8 @@ class ModelRunner:
             lora_requests=lora_requests,
             multi_modal_input=multi_modal_input,
             slot_mapping=slot_mapping,
+            lf1_caches=lf1_caches,
+            lf2_caches=lf2_caches,
         )

     def _prepare_decode(
@@ -427,9 +483,14 @@ class ModelRunner:
         lora_index_mapping: List[int] = []
         lora_prompt_mapping: List[int] = []
         lora_requests: Set[LoRARequest] = set()
+        lf1_caches = [[] for i in range(self.num_layers)]
+        lf2_caches = [[] for i in range(self.num_layers)]

         if len(seq_group_metadata_list) == 0:
-            return PrepareDecodeMetadata.empty()
+            if self.model_config.hf_config.model_type == 'yuan':
+                return PrepareDecodeMetadata.empty(lf1_caches, lf2_caches)
+            else:
+                return PrepareDecodeMetadata.empty()

         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
@@ -441,6 +502,11 @@ class ModelRunner:
             if lora_id > 0:
                 lora_requests.add(seq_group_metadata.lora_request)

+            if self.model_config.hf_config.model_type == 'yuan':
+                for i in range(self.num_layers):
+                    lf1_caches[i].extend(seq_group_metadata.lf1_caches[i])
+                    lf2_caches[i].extend(seq_group_metadata.lf2_caches[i])

             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -539,6 +605,8 @@ class ModelRunner:
             lora_prompt_mapping=lora_prompt_mapping,
             lora_requests=lora_requests,
             slot_mapping=slot_mapping,
+            lf1_caches=lf1_caches,
+            lf2_caches=lf2_caches,
         )

     def _prepare_sample(
@@ -652,7 +720,8 @@ class ModelRunner:
                 prefill_reqs.append(seq_group_meta)
             else:
                 decode_reqs.append(seq_group_meta)
+        lf1_caches = [None for _ in range(self.num_layers)]
+        lf2_caches = [None for _ in range(self.num_layers)]

         # Prepare input tensors.
         (
             input_tokens,
@@ -665,6 +734,8 @@ class ModelRunner:
             lora_requests,
             multi_modal_input,
             slot_mapping,
+            prompt_lf1_caches,
+            prompt_lf2_caches,
         ) = self._prepare_prompt(prefill_reqs)
         (
             decode_input_tokens,
@@ -674,7 +745,22 @@ class ModelRunner:
             decode_lora_prompt_mapping,
             decode_lora_requests,
             decode_slot_mapping,
+            decode_lf1_caches,
+            decode_lf2_caches,
         ) = self._prepare_decode(decode_reqs)
+        if self.model_config.hf_config.model_type == 'yuan':
+            for i in range(self.num_layers):
+                if len(prompt_lf1_caches[i]) > 0 and len(decode_lf1_caches[i]) > 0:
+                    lf1_caches[i] = torch.cat(prompt_lf1_caches[i] + decode_lf1_caches[i], dim=0)
+                    lf2_caches[i] = torch.cat(prompt_lf2_caches[i] + decode_lf2_caches[i], dim=0)
+                elif len(prompt_lf1_caches[i]) > 0:
+                    lf1_caches[i] = torch.cat(prompt_lf1_caches[i], dim=0)
+                    lf2_caches[i] = torch.cat(prompt_lf2_caches[i], dim=0)
+                elif len(decode_lf1_caches[i]) > 0:
+                    lf1_caches[i] = torch.cat(decode_lf1_caches[i], dim=0)
+                    lf2_caches[i] = torch.cat(decode_lf2_caches[i], dim=0)

         sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                  prompt_lens,
                                                  subquery_lens)
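The concatenation above batches the per-sequence LF tensors of each layer into a single tensor along dim 0, prefill sequences first, then decode sequences. A small self-contained sketch of that merge for one layer (sizes are illustrative):

    import torch

    hidden_size = 8
    # Two prefill sequences and one decode sequence, each with a (1, hidden_size, 1, 1) lf1 entry.
    prompt_lf1 = [torch.zeros((1, hidden_size, 1, 1)) for _ in range(2)]
    decode_lf1 = [torch.ones((1, hidden_size, 1, 1))]

    merged = torch.cat(prompt_lf1 + decode_lf1, dim=0)
    print(merged.shape)  # torch.Size([3, 8, 1, 1]) -> batch dim = prefill + decode sequences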
@@ -791,6 +877,14 @@ class ModelRunner:
             metadata_dict = broadcast_tensor_dict(src=0)
             decode_attn_metadata = self.attn_backend.make_metadata(
                 **metadata_dict)
+            bsz = num_prefills + num_decode_tokens
+            lf1_caches = [torch.zeros((bsz, self.hidden_size, 1, 1), dtype=self.model_config.dtype, device=self.device) for i in range(self.num_layers)]
+            lf2_caches = [torch.zeros((bsz, self.hidden_size // 2, 1, 1), dtype=self.model_config.dtype, device=self.device) for i in range(self.num_layers)]
+            prefill_reqs = []
+            decode_reqs = []
+            # for b in range(bsz):
+            #     prefill_reqs.append(prefill_attn_metadata)
+            #     decode_reqs.append(decode_attn_metadata)

         attn_metadata = AttentionMetadata(
             num_prefills=num_prefills,
@@ -801,20 +895,37 @@ class ModelRunner:
             decode_metadata=decode_attn_metadata,
             kv_cache_dtype=self.kv_cache_dtype,
         )

         return (input_tokens, input_positions, attn_metadata,
                 sampling_metadata, lora_requests, lora_mapping,
-                multi_modal_input)
+                multi_modal_input, lf1_caches, lf2_caches, prefill_reqs, decode_reqs)

+    def update_lf_caches(self, prefill_reqs: List[PreparePromptMetadata], decode_reqs: List[PrepareDecodeMetadata], lf1_caches: List[torch.Tensor], lf2_caches: List[torch.Tensor]):
+        start_idx = 0
+        if len(prefill_reqs) > 0:
+            for seq_group_metadata in prefill_reqs:
+                num_seqs = len(seq_group_metadata.seq_data.keys())
+                for i in range(self.num_layers):
+                    for j in range(len(seq_group_metadata.lf1_caches[i])):
+                        seq_group_metadata.lf1_caches[i][j].copy_(lf1_caches[i][start_idx+j:start_idx+j+1].detach().clone())
+                        seq_group_metadata.lf2_caches[i][j].copy_(lf2_caches[i][start_idx+j:start_idx+j+1].detach().clone())
+                start_idx += num_seqs
+        if len(decode_reqs) > 0:
+            for seq_group_metadata in decode_reqs:
+                num_seqs = len(seq_group_metadata.seq_data.keys())
+                for i in range(self.num_layers):
+                    for j in range(len(seq_group_metadata.lf1_caches[i])):
+                        seq_group_metadata.lf1_caches[i][j].copy_(lf1_caches[i][start_idx+j:start_idx+j+1].detach().clone())
+                        seq_group_metadata.lf2_caches[i][j].copy_(lf2_caches[i][start_idx+j:start_idx+j+1].detach().clone())
+                start_idx += num_seqs

     @torch.inference_mode()
     def execute_model(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
-        lf_caches: List[Tuple[torch.Tensor, torch.Tensor]] = None
+        update_lf_caches: bool = False,
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
-         lora_requests, lora_mapping, multi_modal_input
+         lora_requests, lora_mapping, multi_modal_input, lf1_caches, lf2_caches, prefill_reqs, decode_reqs
         ) = self.prepare_input_tensors(seq_group_metadata_list)

         if self.lora_config:
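update_lf_caches() above writes each row of the batched LF tensors produced by the forward pass back into the per-sequence cache tensors held on the sequence group metadata, using start_idx as a running batch offset. A stripped-down sketch of that copy-back for one layer, with a plain list standing in for the metadata objects:

    import torch

    hidden_size = 4
    # Batched lf1 tensor the model produced for 3 sequences (one layer shown).
    batched_lf1 = torch.arange(3 * hidden_size, dtype=torch.float32).reshape(3, hidden_size, 1, 1)

    # Per-sequence cache tensors, one per sequence in this illustration.
    per_seq_lf1 = [torch.zeros((1, hidden_size, 1, 1)) for _ in range(3)]

    start_idx = 0
    for j, cache in enumerate(per_seq_lf1):
        # Same pattern as the diff: slice one batch row out and copy it in place.
        cache.copy_(batched_lf1[start_idx + j:start_idx + j + 1].detach().clone())

    print(per_seq_lf1[2].flatten())  # the last sequence now holds the last batch row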
@@ -834,13 +945,15 @@ class ModelRunner:
             "kv_caches": kv_caches,
             "attn_metadata": attn_metadata,
         }
-        if lf_caches != None:
-            batch_size = attn_metadata.num_prefills + attn_metadata.num_decode_tokens
-            execute_model_kwargs.update({'lf_caches': [(lf1_cache[:batch_size], lf2_cache[:batch_size]) for (lf1_cache, lf2_cache) in lf_caches] if lf_caches[0] != (None, None) else lf_caches})
+        if self.model_config.hf_config.model_type == 'yuan':
+            execute_model_kwargs.update({'lf1_caches': lf1_caches})
+            execute_model_kwargs.update({'lf2_caches': lf2_caches})
         if self.vision_language_config:
             execute_model_kwargs.update({"image_input": multi_modal_input})
         hidden_states = model_executable(**execute_model_kwargs)
+        if self.model_config.hf_config.model_type == 'yuan' and update_lf_caches and (len(prefill_reqs) + len(decode_reqs)) > 0:
+            self.update_lf_caches(prefill_reqs, decode_reqs, lf1_caches, lf2_caches)

         # Compute the logits.
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

         # Only perform sampling in the driver worker.
@@ -854,7 +967,7 @@ class ModelRunner:
         return output

     @torch.inference_mode()
-    def profile_run(self, use_lf_caches=False) -> None:
+    def profile_run(self) -> None:
         # Enable top-k sampling to reflect the accurate memory usage.
         sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
@@ -901,6 +1014,8 @@ class ModelRunner:
                        (group_id < max_num_batched_tokens % max_num_seqs))
             seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                 seq_len, self.vision_language_config)
+            lf1_caches = [[torch.zeros((1, self.hidden_size, 1, 1), dtype=self.model_config.dtype, device=self.device)] for i in range(self.num_layers)]
+            lf2_caches = [[torch.zeros((1, self.hidden_size // 2, 1, 1), dtype=self.model_config.dtype, device=self.device)] for i in range(self.num_layers)]
             seq = SequenceGroupMetadata(
                 request_id=str(group_id),
                 is_prompt=True,
@@ -910,15 +1025,16 @@ class ModelRunner:
                 lora_request=dummy_lora_requests_per_seq[group_id]
                 if dummy_lora_requests_per_seq else None,
                 multi_modal_data=fake_multi_modal_input,
+                lf1_caches=lf1_caches,
+                lf2_caches=lf2_caches,
             )
             seqs.append(seq)

         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        lf_caches = [(None, None)] * num_layers if use_lf_caches else None
-        self.execute_model(seqs, kv_caches, lf_caches)
+        self.execute_model(seqs, kv_caches)
         torch.cuda.synchronize()
         return
@@ -949,7 +1065,7 @@ class ModelRunner:
         return self.lora_manager.list_loras()

     @torch.inference_mode()
-    def capture_model(self, kv_caches: List[torch.Tensor], lf_caches: List[LFCache] = None) -> None:
+    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
         """Cuda graph capture a model.

         Note that CUDA graph's performance gain is negligible if number
@@ -1004,6 +1120,8 @@ class ModelRunner:
         # memory usage of CUDA graph.
         for batch_size in reversed(batch_size_capture_list):
             # Create dummy attn_metadata.
+            lf1_caches = [torch.zeros((batch_size, self.hidden_size, 1, 1), dtype=self.model_config.dtype, device=self.device) for i in range(self.num_layers)]
+            lf2_caches = [torch.zeros((batch_size, self.hidden_size // 2, 1, 1), dtype=self.model_config.dtype, device=self.device) for i in range(self.num_layers)]
             decode_metadata = self.attn_backend.make_metadata(
                 is_prompt=False,
                 prompt_lens=None,
@@ -1035,23 +1153,16 @@ class ModelRunner:
                 self.set_active_loras(set(), lora_mapping)

             graph_runner = CUDAGraphRunner(self.model)
-            if lf_caches != None:
-                graph_runner.capture(
-                    input_tokens[:batch_size],
-                    input_positions[:batch_size],
-                    kv_caches,
-                    [(lf1_cache[:batch_size], lf2_cache[:batch_size]) for (lf1_cache, lf2_cache) in lf_caches],
-                    attn_metadata,
-                    memory_pool=self.graph_memory_pool,
-                )
-            else:
-                graph_runner.capture(
-                    input_tokens[:batch_size],
-                    input_positions[:batch_size],
-                    kv_caches,
-                    attn_metadata,
-                    memory_pool=self.graph_memory_pool,
-                )
+            graph_runner.capture(
+                input_tokens[:batch_size],
+                input_positions[:batch_size],
+                kv_caches,
+                lf1_caches,
+                lf2_caches,
+                attn_metadata,
+                memory_pool=self.graph_memory_pool,
+                model_type=self.model_config.hf_config.model_type
+            )
             self.graph_memory_pool = graph_runner.graph.pool()
             self.graph_runners[batch_size] = graph_runner
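capture() below records the model call into a CUDA graph, so every tensor the graph reads (including lf1_caches/lf2_caches) has to stay at a fixed address and be refreshed in place before each replay; that is why the diff registers them in input_buffers. A generic sketch of that static-buffer pattern, unrelated to vLLM's own classes and requiring a CUDA device; the multiply stands in for the model forward:

    import torch

    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Capture the kernel launches into a graph; tensors captured here are reused on every replay.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(static_in * 2)  # stand-in for self.model(...)

    # Replay: refresh the static inputs in place, then relaunch the recorded kernels.
    static_in.copy_(torch.arange(8.0, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_out)  # tensor([ 0.,  2.,  4., ..., 14.], device='cuda:0')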
@@ -1088,9 +1199,11 @@ class CUDAGraphRunner:
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
-        lf_caches: List[LFCache],
+        lf1_caches: List[torch.Tensor],
+        lf2_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         memory_pool,
+        model_type: str,
         **kwargs,
     ) -> None:
         assert self.graph is None
@@ -1098,11 +1211,13 @@ class CUDAGraphRunner:
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
         with _maybe_pynccl():
-            if lf_caches == None:
+            if model_type == 'yuan':
                 self.model(
                     input_ids,
                     positions,
                     kv_caches,
+                    lf1_caches,
+                    lf2_caches,
                     attn_metadata,
                     **kwargs,
                 )
@@ -1111,7 +1226,6 @@ class CUDAGraphRunner:
                     input_ids,
                     positions,
                     kv_caches,
-                    lf_caches,
                     attn_metadata,
                     **kwargs,
                 )
@@ -1123,12 +1237,13 @@ class CUDAGraphRunner:
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
             with _maybe_pynccl():
-                if lf_caches != None:
+                if model_type == 'yuan':
                     hidden_states = self.model(
                         input_ids,
                         positions,
                         kv_caches,
-                        lf_caches,
+                        lf1_caches,
+                        lf2_caches,
                         attn_metadata,
                         **kwargs,
                     )
@@ -1147,12 +1262,12 @@ class CUDAGraphRunner:
             "input_ids": input_ids,
             "positions": positions,
             "kv_caches": kv_caches,
+            "lf1_caches": lf1_caches,
+            "lf2_caches": lf2_caches,
             "slot_mapping": attn_metadata.slot_mapping,
             "context_lens": attn_metadata.decode_metadata.context_lens,
             "block_tables": attn_metadata.decode_metadata.block_tables,
         }
-        if lf_caches != None:
-            self.input_buffers.update({"lf_caches": lf_caches})
         self.output_buffers = {"hidden_states": hidden_states}
         return
@@ -1161,7 +1276,6 @@ class CUDAGraphRunner:
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
-        lf_caches: List[Tuple[torch.Tensor, torch.Tensor]],
         attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
......
@@ -131,10 +131,7 @@ class Worker(WorkerBase):
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        if self.model_config.hf_config.model_type == 'yuan':
-            self.model_runner.profile_run(use_lf_caches=True)
-        else:
-            self.model_runner.profile_run()
+        self.model_runner.profile_run()

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
@@ -182,12 +179,11 @@ class Worker(WorkerBase):
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.gpu_cache = self.cache_engine.gpu_cache
-        self.lf_gpu_cache = self.cache_engine.lf_gpu_cache if self.model_config.hf_config.model_type == 'yuan' else None
         self.model_runner.set_block_size(self.cache_engine.block_size)

     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
-            self.model_runner.capture_model(self.gpu_cache, self.lf_gpu_cache)
+            self.model_runner.capture_model(self.gpu_cache)
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
@@ -243,7 +239,7 @@ class Worker(WorkerBase):
         if num_seq_groups == 0:
             return []
         output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 self.gpu_cache, self.lf_gpu_cache)
+                                                 self.gpu_cache, update_lf_caches=True)
         # Worker only supports single-step execution. Wrap the output in a list
         # to conform to interface.
......
-import os
+from vllm import LLM, SamplingParams
 import time
-import argparse
+import os
 from transformers import LlamaTokenizer
-from vllm import LLM, SamplingParams
-
-## params
-parser = argparse.ArgumentParser()
-parser.add_argument('--model_path', default='', help='model path')
-args = parser.parse_args()
-model_path = args.model_path
-tokenizer = LlamaTokenizer.from_pretrained(model_path, add_eos_token=False, add_bos_token=False, eos_token='<eod>')
+tokenizer = LlamaTokenizer.from_pretrained('/mnt/beegfs2/Yuan2-M32-HF/', add_eos_token=False, add_bos_token=False, eos_token='<eod>')
 tokenizer.add_tokens(['<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>','<commit_before>','<commit_msg>','<commit_after>','<jupyter_start>','<jupyter_text>','<jupyter_code>','<jupyter_output>','<empty_output>'], special_tokens=True)
 prompts = ["写一篇春游作文"]
 sampling_params = SamplingParams(max_tokens=300, temperature=1, top_p=0, top_k=1, min_p=0.0, length_penalty=1.0, repetition_penalty=1.0, stop="<eod>", )
-## init model
-llm = LLM(model=model_path, trust_remote_code=True, tensor_parallel_size=8, gpu_memory_utilization=0.8, disable_custom_all_reduce=True, max_num_seqs=1)
-## inference
+llm = LLM(model="/mnt/beegfs2/Yuan2-M32-HF/", trust_remote_code=True, tensor_parallel_size=8, gpu_memory_utilization=0.8, disable_custom_all_reduce=True, max_num_seqs=1)
 start_time = time.time()
 outputs = llm.generate(prompts, sampling_params)
 end_time = time.time()
......