Commit 6a583c2f authored by chenych

update dtk to 24.04.1 and modify README

parent 7d576a9a
@@ -181,7 +181,7 @@ class LLMEngine:
         # Create the scheduler.
         # NOTE: the cache_config here has been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
-        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
+        self.scheduler = Scheduler(model_config, parallel_config, scheduler_config, cache_config, lora_config)

         # Metric Logging.
         if self.log_stats:
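The first hunk widens the Scheduler constructor: the call site now passes model_config and parallel_config ahead of the existing arguments. The receiving side would have to look roughly like the sketch below; only the parameter order is implied by this commit, and the stored attribute names are assumptions.

```python
# Hypothetical sketch of the Scheduler.__init__ this call site implies.
# The parameter order comes from the call in the hunk above; everything
# else (attribute names, what the configs are later used for) is assumed.
class Scheduler:
    def __init__(self, model_config, parallel_config, scheduler_config,
                 cache_config, lora_config):
        self.model_config = model_config        # assumed: model-type checks
        self.parallel_config = parallel_config  # assumed: per-rank sizing
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.lora_config = lora_config
```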
@@ -204,6 +204,13 @@ class LLMEngine:
                     self.get_tokenizer_for_seq,
                 ),
             ))
+        self.num_layers = model_config.get_num_layers(parallel_config)
+        self.total_num_heads = model_config.hf_config.num_attention_heads
+        self.head_size = model_config.get_head_size()
+        self.hidden_size = self.head_size * self.total_num_heads
+        self.device_config = (device_config if device_config is not None else DeviceConfig())
+        self.device = self.device_config.device
+        self.dtype = self.model_config.dtype

     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -385,7 +392,8 @@ class LLMEngine:
             lora_request).eos_token_id
         seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                        eos_token_id, lora_request)
+        if self.model_config.hf_config.model_type == 'yuan':
+            seq.create_lf_caches(self.hidden_size, self.num_layers, self.device, self.dtype)

         # Defensive copy of SamplingParams, which are used by the sampler;
         # note that this does not deep-copy LogitsProcessor objects.
         sampling_params = sampling_params.clone()
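The third hunk is the consumer of those cached attributes: for the Yuan model type, every new Sequence pre-allocates LF caches (presumably for the Yuan model's Localized Filtering-based Attention) sized from hidden_size and num_layers. create_lf_caches itself is not part of this diff; below is a plausible sketch, assuming it zero-initializes one buffer per decoder layer. The tensor shape is an assumption; only the argument list is implied by the call site.

```python
import torch

# Hypothetical Sequence method: one zeroed cache tensor per decoder layer.
# The (1, hidden_size) shape is an assumption; only the argument list
# (hidden_size, num_layers, device, dtype) appears in the diff above.
def create_lf_caches(self, hidden_size, num_layers, device, dtype):
    self.lf_caches = [
        torch.zeros(1, hidden_size, device=device, dtype=dtype)
        for _ in range(num_layers)
    ]
```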