Unverified commit 1b7b161a, authored by Shrey Gupta, committed by GitHub
Browse files

[Feature] models: pass layer prefix to replace_linear_class for per-layer...


[Feature] models: pass layer prefix to replace_linear_class for per-layer quantization routing. Addresses #23239 (#23556)
Signed-off-by: Shrey Gupta <shreyg1303@gmail.com>
parent a69693e3
...@@ -408,13 +408,17 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -408,13 +408,17 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
parent, attr_name = self._get_parent_and_attr(vit, name) parent, attr_name = self._get_parent_and_attr(vit, name)
if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1": if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1":
new_linear = replace_linear_class(module, "colwise", new_linear = replace_linear_class(module,
quant_config) "colwise",
quant_config,
prefix=name)
setattr(parent, attr_name, new_linear) setattr(parent, attr_name, new_linear)
elif isinstance(parent, elif isinstance(parent,
timm.layers.Mlp) and attr_name == "fc2": timm.layers.Mlp) and attr_name == "fc2":
new_linear = replace_linear_class(module, "rowwise", new_linear = replace_linear_class(module,
quant_config) "rowwise",
quant_config,
prefix=name)
setattr(parent, attr_name, new_linear) setattr(parent, attr_name, new_linear)
return vit return vit
......
...@@ -106,8 +106,11 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: ...@@ -106,8 +106,11 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
def replace_linear_class( def replace_linear_class(
linear: nn.Linear, style: Literal["colwise", "rowwise"], linear: nn.Linear,
quant_config: QuantizationConfig style: Literal["colwise", "rowwise"],
quant_config: QuantizationConfig,
*,
prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
""" """
Replace nn.Linear with one of vLLM's tensor parallel linear classes. Replace nn.Linear with one of vLLM's tensor parallel linear classes.
...@@ -141,6 +144,7 @@ def replace_linear_class( ...@@ -141,6 +144,7 @@ def replace_linear_class(
output_size=linear.out_features, output_size=linear.out_features,
bias=linear.bias is not None, bias=linear.bias is not None,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix,
return_bias=False, return_bias=False,
**vllm_linear_kwargs, **vllm_linear_kwargs,
) )
...@@ -557,8 +561,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -557,8 +561,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
generator = (p for p in tp_plan if re.match(p, qual_name)) generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None) pattern = next(generator, None)
style = tp_plan.get(pattern, "replicate") style = tp_plan.get(pattern, "replicate")
new_module = replace_linear_class(child_module, style, new_module = replace_linear_class(child_module,
self.quant_config) style,
self.quant_config,
prefix=qual_name)
setattr(module, child_name, new_module) setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module) log_replacement(qual_name, child_module, new_module)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment