Commit 9e813a0e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev' into v0.7.2-fusion

parents f44e9f9e abf008ef
......@@ -488,11 +488,11 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) == ('2', '4'):
version = 'das.opt1.beta.' + sha[:7]
version = 'das.opt2.alpha.' + sha[:7]
# version = 'das.opt1.' + sha[:7]
else:
if (major, minor) == ('2', '4'):
version = 'das.opt1.beta'
version = 'das.opt2.alpha'
# version = 'das.opt1'
......
......@@ -1328,6 +1328,8 @@ class LLMEngine:
while True:
self.sem_m2s.acquire()
if not self.thread_running:
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
break
virtual_engine = 0
......@@ -1438,8 +1440,9 @@ class LLMEngine:
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
# logger.debug("Stopping remote worker execution loop.")
# self.model_executor.stop_remote_worker_execution_loop()
self.finish_thread()
return ctx.request_outputs
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
......
......@@ -48,10 +48,13 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
f"is less than than max local gpu count ({cuda_device_count})")
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
if "CUDA_VISIBLE_DEVICES" or "HIP_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
update_environment_variables({
"HIP_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
def _init_executor(self) -> None:
......
......@@ -79,10 +79,9 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
# 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration'
support_nn_architectures = ['QWenLMHeadModel',
'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
# 'Qwen2_5_VLForConditionalGeneration'
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM',
'Qwen2VLForConditionalGeneration','ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM',
'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM',
'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
......@@ -202,4 +201,4 @@ def configure_quant_config(quant_config: QuantizationConfig,
logger.warning(
"The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored "
"modules", model_class.__name__)
"modules", model_class.__name__)
\ No newline at end of file
......@@ -407,6 +407,11 @@ def safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
total_count = 0
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
total_count += len(f.keys())
current_count = 0
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for st_file in tqdm(
......@@ -417,7 +422,10 @@ def safetensors_weights_iterator(
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
current_count += 1
param = f.get_tensor(name)
param.current_count = current_count
param.total_count = total_count
yield name, param
......
......@@ -450,6 +450,8 @@ class LlamaModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
......@@ -502,7 +504,7 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None :
if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
......@@ -515,7 +517,8 @@ class LlamaModel(nn.Module):
# qkv_words = "|".join(lay_qkv_words)
# for layername, weight in params_dict.items():
for layername in loaded_params:
# for layername in loaded_params:
for layername in params_dict.keys():
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 4096:
lay_key_words.append("lm_head.weight")
......
......@@ -398,6 +398,8 @@ class Qwen2Model(nn.Module):
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
......@@ -440,7 +442,7 @@ class Qwen2Model(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
......@@ -456,7 +458,8 @@ class Qwen2Model(nn.Module):
# qkv_bias_words = "|".join(lay_qkv_bias_words)
# for layername, weight in params_dict.items():
for layername in loaded_params:
# for layername in loaded_params:
for layername in params_dict.keys():
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 3584:
lay_key_words.append("lm_head.weight")
......
......@@ -1300,17 +1300,36 @@ class HiddenStates(msgspec.Struct, array_like=True,
"""Update hidden states from target model invocation. Only used for
decode steps"""
assert len(seq_group_metadata_list) == len(hidden_states)
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
# self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
# self.hidden_states = torch.cat([self.hidden_states, hidden_states])
# if self.second_last_token_hidden_states is not None:
# # Adding dummy hidden_states to this to maintain same shape
# self.second_last_token_hidden_states = torch.cat([
# self.second_last_token_hidden_states,
# torch.zeros_like(hidden_states)
# if second_last_token_hidden_states is None else
# second_last_token_hidden_states
# ])
seq_ids = get_all_seq_ids(seq_group_metadata_list)
diff_seq_ids = [item for item in self._seq_ids if item not in seq_ids]
index = [self._seq_ids.index(seq_id) for seq_id in diff_seq_ids]
self._seq_ids = diff_seq_ids
self.hidden_states = self.hidden_states[index]
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
if self.second_last_token_hidden_states is not None:
# Adding dummy hidden_states to this to maintain same shape
self.second_last_token_hidden_states = self.second_last_token_hidden_states[index]
self.second_last_token_hidden_states = torch.cat([
self.second_last_token_hidden_states,
torch.zeros_like(hidden_states)
if second_last_token_hidden_states is None else
second_last_token_hidden_states
])
self._seq_ids.extend(seq_ids)
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
......
......@@ -691,15 +691,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]]
if not skip_proposer:
if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_meta_with_hidden)
elif self.previous_hidden_states and len(
seq_group_meta_with_hidden):
self.previous_hidden_states.update(hidden_states,
seq_group_meta_with_hidden)
if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_meta_with_hidden)
elif self.previous_hidden_states and len(
seq_group_meta_with_hidden):
self.previous_hidden_states.update(hidden_states,
seq_group_meta_with_hidden)
# Store logits from target model execution.
if self.tree_decoding:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment