Commit 6a583c2f authored by chenych

update dtk to 24.04.1 and modify README

parent 7d576a9a
......@@ -52,9 +52,6 @@ class CacheEngine:
# Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
if self.model_config.hf_config.model_type == 'yuan':
self.lf_gpu_cache = self._allocate_lf_cache(self.cache_config.max_num_seqs, "cuda")
self.lf_cpu_cache = self._allocate_lf_cache(self.cache_config.max_num_seqs, "cpu")
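For context, a minimal sketch of what _allocate_lf_cache might look like, inferred from the lf tensor shapes used elsewhere in this commit ((batch, hidden_size, 1, 1) for lf1 and (batch, hidden_size // 2, 1, 1) for lf2); the real method belongs to CacheEngine and may differ, and the num_layers/hidden_size/dtype attributes are assumptions of the sketch:

def _allocate_lf_cache(self, max_num_seqs: int, device: str):
    # One (lf1, lf2) pair per layer, sized for the maximum number of
    # concurrently running sequences.
    lf_cache = []
    for _ in range(self.num_layers):
        lf1 = torch.zeros((max_num_seqs, self.hidden_size, 1, 1), dtype=self.dtype, device=device)
        lf2 = torch.zeros((max_num_seqs, self.hidden_size // 2, 1, 1), dtype=self.dtype, device=device)
        lf_cache.append((lf1, lf2))
    return lf_cache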
def _allocate_kv_cache(
self,
......
......@@ -51,21 +51,41 @@ class PreparePromptMetadata(NamedTuple):
lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor]
slot_mapping: List[int]
lf1_caches: List[List[torch.Tensor]]
lf2_caches: List[List[torch.Tensor]]
@classmethod
def empty(cls, lf1_caches=None, lf2_caches=None):
    # Non-Yuan callers pass no lf caches; fall back to a single empty
    # per-layer list so downstream code can index uniformly.
    return PreparePromptMetadata(
        input_tokens=[],
        input_positions=[],
        attn_metadata=None,
        prompt_lens=[],
        subquery_lens=[],
        lora_index_mapping=[],
        lora_prompt_mapping=[],
        lora_requests=set(),
        multi_modal_input=None,
        slot_mapping=[],
        lf1_caches=lf1_caches if lf1_caches is not None else [[]],
        lf2_caches=lf2_caches if lf2_caches is not None else [[]],
    )
class PrepareDecodeMetadata(NamedTuple):
......@@ -76,18 +96,35 @@ class PrepareDecodeMetadata(NamedTuple):
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
slot_mapping: List[int]
lf1_caches: List[List[torch.Tensor]]
lf2_caches: List[List[torch.Tensor]]
@classmethod
def empty(cls, lf1_caches=None, lf2_caches=None):
    return PrepareDecodeMetadata(
        input_tokens=[],
        input_positions=[],
        attn_metadata=None,
        lora_index_mapping=[],
        lora_prompt_mapping=[],
        lora_requests=set(),
        slot_mapping=[],
        lf1_caches=lf1_caches if lf1_caches is not None else [[]],
        lf2_caches=lf2_caches if lf2_caches is not None else [[]],
    )
# How batches are constructed.
......@@ -152,6 +189,10 @@ class ModelRunner:
self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None)
self.num_layers = model_config.get_num_layers(parallel_config)
self.total_num_heads = model_config.hf_config.num_attention_heads
self.head_size = model_config.get_head_size()
self.hidden_size = self.head_size * self.total_num_heads
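These attributes size the Yuan lf-cache tensors allocated later in this diff. As an illustrative check (hypothetical values, not from the source): with 64 attention heads of head size 128, hidden_size = 64 * 128 = 8192, so the per-layer caches for a batch of 4 sequences would have shapes (4, 8192, 1, 1) for lf1 and (4, 4096, 1, 1) for lf2.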
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
......@@ -232,9 +273,13 @@ class ModelRunner:
subquery_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
lf1_caches: List[List[torch.Tensor]] = [[] for _ in range(self.num_layers)]
lf2_caches: List[List[torch.Tensor]] = [[] for _ in range(self.num_layers)]
if len(seq_group_metadata_list) == 0:
if self.model_config.hf_config.model_type == 'yuan':
return PreparePromptMetadata.empty(lf1_caches=lf1_caches, lf2_caches=lf2_caches)
else:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
......@@ -263,6 +308,8 @@ class ModelRunner:
prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention.
# zhaoxd: only seq_groups in the RUNNING stage have computed_block_nums != None;
# zhaoxd: otherwise prefix_block_tables only ever gets empty lists appended
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
......@@ -306,13 +353,20 @@ class ModelRunner:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
if self.model_config.hf_config.model_type == 'yuan':
for i in range(self.num_layers):
lf1_caches[i].extend(seq_group_metadata.lf1_caches[i])
lf2_caches[i].extend(seq_group_metadata.lf2_caches[i])
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
# zhaoxd: during profiling, block tables are not initialized yet, so a fake slot mapping is enough
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
continue
# Compute the slot mapping.
# zhaoxd: block_table holds the block tables of the current seq group
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
......@@ -325,12 +379,12 @@ class ModelRunner:
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
# zhaoxd: normally computed_len=0, prefill_end=len(prompt), start_idx=0 (with sliding_window enabled, start from the sliding_window position)
for i in range(computed_len, prefill_end):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
# zhaoxd: the slot is the offset within this group of block_tables
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
......@@ -399,7 +453,7 @@ class ModelRunner:
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor,
block_tables=block_tables,  # zhaoxd: block_tables of the matched prefix, unrelated to the block_tables allocated to this seq_group
use_cuda_graph=False,
)
return PreparePromptMetadata(
......@@ -413,6 +467,8 @@ class ModelRunner:
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
lf1_caches=lf1_caches,
lf2_caches=lf2_caches,
)
def _prepare_decode(
......@@ -427,9 +483,14 @@ class ModelRunner:
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
lf1_caches: List[List[torch.Tensor]] = [[] for _ in range(self.num_layers)]
lf2_caches: List[List[torch.Tensor]] = [[] for _ in range(self.num_layers)]
if len(seq_group_metadata_list) == 0:
if self.model_config.hf_config.model_type == 'yuan':
return PrepareDecodeMetadata.empty(lf1_caches, lf2_caches)
else:
return PrepareDecodeMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
......@@ -441,6 +502,11 @@ class ModelRunner:
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
if self.model_config.hf_config.model_type == 'yuan':
for i in range(self.num_layers):
lf1_caches[i].extend(seq_group_metadata.lf1_caches[i])
lf2_caches[i].extend(seq_group_metadata.lf2_caches[i])
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
......@@ -539,6 +605,8 @@ class ModelRunner:
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
lf1_caches=lf1_caches,
lf2_caches=lf2_caches,
)
def _prepare_sample(
......@@ -652,7 +720,8 @@ class ModelRunner:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
lf1_caches = [None for _ in range(self.num_layers)]
lf2_caches = [None for _ in range(self.num_layers)]
# Prepare input tensors.
(
input_tokens,
......@@ -665,6 +734,8 @@ class ModelRunner:
lora_requests,
multi_modal_input,
slot_mapping,
prompt_lf1_caches,
prompt_lf2_caches,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
......@@ -674,7 +745,22 @@ class ModelRunner:
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
decode_lf1_caches,
decode_lf2_caches,
) = self._prepare_decode(decode_reqs)
if self.model_config.hf_config.model_type == 'yuan':
    for i in range(self.num_layers):
        # Batch per-sequence lf tensors along dim 0: prefill entries
        # first, then decode entries.
        layer_lf1 = prompt_lf1_caches[i] + decode_lf1_caches[i]
        layer_lf2 = prompt_lf2_caches[i] + decode_lf2_caches[i]
        if len(layer_lf1) > 0:
            lf1_caches[i] = torch.cat(layer_lf1, dim=0)
            lf2_caches[i] = torch.cat(layer_lf2, dim=0)
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)
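A standalone toy illustration of the batching above (shapes and values are illustrative only): per-sequence lf tensors of shape (1, hidden_size, 1, 1) are concatenated along dim 0 into one batch tensor, prefill sequences first, then decode sequences; update_lf_caches() later relies on exactly this ordering when copying rows back.

import torch

hidden_size = 8  # illustrative value, not from the model config
prompt_lf1 = [torch.zeros(1, hidden_size, 1, 1)]      # one prefill sequence
decode_lf1 = [torch.ones(1, hidden_size, 1, 1)] * 2   # two decode sequences
batched = torch.cat(prompt_lf1 + decode_lf1, dim=0)
assert batched.shape == (3, hidden_size, 1, 1)
assert batched[0].sum() == 0 and batched[1].sum() == hidden_size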
......@@ -791,6 +877,14 @@ class ModelRunner:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
bsz = num_prefills + num_decode_tokens
lf1_caches = [torch.zeros((bsz, self.hidden_size, 1, 1), dtype=self.model_config.dtype, device=self.device) for _ in range(self.num_layers)]
lf2_caches = [torch.zeros((bsz, self.hidden_size // 2, 1, 1), dtype=self.model_config.dtype, device=self.device) for _ in range(self.num_layers)]
prefill_reqs = []
decode_reqs = []
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
......@@ -801,20 +895,37 @@ class ModelRunner:
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_input, lf1_caches, lf2_caches, prefill_reqs, decode_reqs)
def update_lf_caches(self, prefill_reqs: List[SequenceGroupMetadata],
                     decode_reqs: List[SequenceGroupMetadata],
                     lf1_caches: List[torch.Tensor],
                     lf2_caches: List[torch.Tensor]):
    # Copy each row of the batched lf tensors back into the owning
    # sequence group's per-sequence cache tensors, walking prefill
    # requests first and decode requests second, i.e. the same order in
    # which prepare_input_tensors concatenated them.
    start_idx = 0
    for seq_group_metadata in prefill_reqs + decode_reqs:
        num_seqs = len(seq_group_metadata.seq_data.keys())
        for i in range(self.num_layers):
            for j in range(len(seq_group_metadata.lf1_caches[i])):
                seq_group_metadata.lf1_caches[i][j].copy_(lf1_caches[i][start_idx + j:start_idx + j + 1].detach().clone())
                seq_group_metadata.lf2_caches[i][j].copy_(lf2_caches[i][start_idx + j:start_idx + j + 1].detach().clone())
        start_idx += num_seqs
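The copy-back relies on that prefill-then-decode row ordering; a standalone toy sketch of the row accounting (names, shapes, and values are illustrative only):

import torch

# Two sequence groups: the first owns one sequence, the second owns two.
group_caches = [[torch.zeros(1, 4, 1, 1)],
                [torch.zeros(1, 4, 1, 1), torch.zeros(1, 4, 1, 1)]]
# Batched result with rows 0, 1, 2 (row index broadcast across features).
batched = torch.arange(3.0).reshape(3, 1, 1, 1).expand(3, 4, 1, 1)
start_idx = 0
for caches in group_caches:
    for j in range(len(caches)):
        caches[j].copy_(batched[start_idx + j:start_idx + j + 1])
    start_idx += len(caches)
assert group_caches[1][1][0, 0, 0, 0].item() == 2.0  # last row reached the last sequence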
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[torch.Tensor],
update_lf_caches: bool = False,
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_mapping, multi_modal_input, lf1_caches, lf2_caches, prefill_reqs, decode_reqs
) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
......@@ -834,13 +945,15 @@ class ModelRunner:
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.model_config.hf_config.model_type == 'yuan':
execute_model_kwargs.update({'lf1_caches': lf1_caches})
execute_model_kwargs.update({'lf2_caches': lf2_caches})
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
if self.model_config.hf_config.model_type == 'yuan' and update_lf_caches and (len(prefill_reqs) + len(decode_reqs)) > 0:
self.update_lf_caches(prefill_reqs, decode_reqs, lf1_caches, lf2_caches)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
......@@ -854,7 +967,7 @@ class ModelRunner:
return output
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
......@@ -901,6 +1014,8 @@ class ModelRunner:
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data, fake_multi_modal_input = _prepare_fake_inputs(
seq_len, self.vision_language_config)
lf1_caches = [[torch.zeros((1, self.hidden_size, 1, 1), dtype=self.model_config.dtype, device=self.device)] for _ in range(self.num_layers)]
lf2_caches = [[torch.zeros((1, self.hidden_size // 2, 1, 1), dtype=self.model_config.dtype, device=self.device)] for _ in range(self.num_layers)]
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
......@@ -910,15 +1025,16 @@ class ModelRunner:
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=fake_multi_modal_input,
lf1_caches=lf1_caches,
lf2_caches=lf2_caches,
)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
self.execute_model(seqs, kv_caches)
torch.cuda.synchronize()
return
......@@ -949,7 +1065,7 @@ class ModelRunner:
return self.lora_manager.list_loras()
@torch.inference_mode()
def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
"""Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number
......@@ -1004,6 +1120,8 @@ class ModelRunner:
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata.
lf1_caches = [torch.zeros((batch_size, self.hidden_size, 1, 1), dtype=self.model_config.dtype, device=self.device) for _ in range(self.num_layers)]
lf2_caches = [torch.zeros((batch_size, self.hidden_size // 2, 1, 1), dtype=self.model_config.dtype, device=self.device) for _ in range(self.num_layers)]
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
prompt_lens=None,
......@@ -1035,23 +1153,16 @@ class ModelRunner:
self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
kv_caches,
lf1_caches,
lf2_caches,
attn_metadata,
memory_pool=self.graph_memory_pool,
model_type=self.model_config.hf_config.model_type
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
......@@ -1088,9 +1199,11 @@ class CUDAGraphRunner:
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
lf1_caches: List[torch.Tensor],
lf2_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool,
model_type: str,
**kwargs,
) -> None:
assert self.graph is None
......@@ -1098,11 +1211,13 @@ class CUDAGraphRunner:
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
with _maybe_pynccl():
if model_type == 'yuan':
self.model(
input_ids,
positions,
kv_caches,
lf1_caches,
lf2_caches,
attn_metadata,
**kwargs,
)
......@@ -1111,7 +1226,6 @@ class CUDAGraphRunner:
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
......@@ -1123,12 +1237,13 @@ class CUDAGraphRunner:
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
with _maybe_pynccl():
if model_type == 'yuan':
hidden_states = self.model(
input_ids,
positions,
kv_caches,
lf1_caches,
lf2_caches,
attn_metadata,
**kwargs,
)
......@@ -1147,12 +1262,12 @@ class CUDAGraphRunner:
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"lf1_caches": lf1_caches,
"lf2_caches": lf2_caches,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.decode_metadata.context_lens,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
......@@ -1161,7 +1276,6 @@ class CUDAGraphRunner:
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
......
......@@ -131,10 +131,7 @@ class Worker(WorkerBase):
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
......@@ -182,12 +179,11 @@ class Worker(WorkerBase):
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.gpu_cache = self.cache_engine.gpu_cache
self.model_runner.set_block_size(self.cache_engine.block_size)
def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
......@@ -243,7 +239,7 @@ class Worker(WorkerBase):
if num_seq_groups == 0:
return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache, update_lf_caches=True)
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
......
import argparse
import os
import time

from transformers import LlamaTokenizer
from vllm import LLM, SamplingParams

## params
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', default='', help='model path, e.g. /mnt/beegfs2/Yuan2-M32-HF/')
args = parser.parse_args()
model_path = args.model_path

tokenizer = LlamaTokenizer.from_pretrained(model_path, add_eos_token=False, add_bos_token=False, eos_token='<eod>')
tokenizer.add_tokens(['<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>', '<commit_before>', '<commit_msg>', '<commit_after>', '<jupyter_start>', '<jupyter_text>', '<jupyter_code>', '<jupyter_output>', '<empty_output>'], special_tokens=True)

prompts = ["写一篇春游作文"]  # "Write an essay about a spring outing"
sampling_params = SamplingParams(max_tokens=300, temperature=1, top_p=0, top_k=1, min_p=0.0, length_penalty=1.0, repetition_penalty=1.0, stop="<eod>")

## init model
llm = LLM(model=model_path, trust_remote_code=True, tensor_parallel_size=8, gpu_memory_utilization=0.8, disable_custom_all_reduce=True, max_num_seqs=1)

## inference
start_time = time.time()
outputs = llm.generate(prompts, sampling_params)
end_time = time.time()
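Example invocation of the script above (the script filename is illustrative; the model path is the one referenced in this commit):

python test_yuan.py --model_path /mnt/beegfs2/Yuan2-M32-HF/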
......