Commit 6a583c2f authored by chenych's avatar chenych
Browse files

update dtk to 24.04.1 and modify README

parent 7d576a9a
...@@ -181,7 +181,7 @@ class LLMEngine: ...@@ -181,7 +181,7 @@ class LLMEngine:
# Create the scheduler. # Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of # NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor. # GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) self.scheduler = Scheduler(model_config, parallel_config, scheduler_config, cache_config, lora_config)
# Metric Logging. # Metric Logging.
if self.log_stats: if self.log_stats:
...@@ -204,6 +204,13 @@ class LLMEngine: ...@@ -204,6 +204,13 @@ class LLMEngine:
self.get_tokenizer_for_seq, self.get_tokenizer_for_seq,
), ),
)) ))
self.num_layers = model_config.get_num_layers(parallel_config)
self.total_num_heads = model_config.hf_config.num_attention_heads
self.head_size = model_config.get_head_size()
self.hidden_size = self.head_size * self.total_num_heads
self.device_config = (device_config if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.dtype = self.model_config.dtype
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
...@@ -385,7 +392,8 @@ class LLMEngine: ...@@ -385,7 +392,8 @@ class LLMEngine:
lora_request).eos_token_id lora_request).eos_token_id
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request) eos_token_id, lora_request)
if self.model_config.hf_config.model_type == 'yuan':
seq.create_lf_caches(self.hidden_size, self.num_layers, self.device, self.dtype)
# Defensive copy of SamplingParams, which are used by the sampler, # Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects # this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone() sampling_params = sampling_params.clone()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment