Unverified Commit 1a6fcad4 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Improve `TransformersModel` UX (#12785)

parent 56534cd5
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` models""" """Wrapper around `transformers` models"""
import re import re
from typing import Iterable, Optional, Union from typing import Iterable, Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -72,15 +72,24 @@ def vllm_flash_attention_forward( ...@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)
def replace_linear_class( def replace_linear_class(
linear: nn.Linear, linear: nn.Linear,
style: str, style: Literal["colwise", "rowwise"],
quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]: quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
""" """
In model configurations, we use a neutral type (string) to specify parallel Replace nn.Linear with one of vLLM's tensor parallel linear classes.
styles, here we use it to translate nn.Linear into vllm-style tp Linear.
`quant_config` is not yet supported.
Quant config is not supported yet Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
quant_config (QuantConfig): Quantization config for the new linear.
Returns:
Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
""" """
if not isinstance(style, str): if not isinstance(style, str):
...@@ -93,7 +102,10 @@ def replace_linear_class( ...@@ -93,7 +102,10 @@ def replace_linear_class(
}.get(style) }.get(style)
if vllm_linear_cls is None: if vllm_linear_cls is None:
raise ValueError(f"Unsupported parallel style value: {style}") logger.warning(
"Unsupported parallel style value: %s. "
"This layer will not be tensor parallelized.", style)
return linear
class HFCompatibleLinear(vllm_linear_cls): class HFCompatibleLinear(vllm_linear_cls):
""" """
...@@ -119,25 +131,24 @@ class TransformersModel(nn.Module): ...@@ -119,25 +131,24 @@ class TransformersModel(nn.Module):
super().__init__() super().__init__()
logger.info("Using Transformers backend.") logger.info("Using Transformers backend.")
self.vllm_config = vllm_config
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.config = config self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.model: PreTrainedModel = AutoModel.from_config( self.model: PreTrainedModel = AutoModel.from_config(
self.config, self.config,
attn_implementation="vllm", attn_implementation="vllm",
torch_dtype=vllm_config.model_config.dtype,
trust_remote_code=vllm_config.model_config.trust_remote_code, trust_remote_code=vllm_config.model_config.trust_remote_code,
) )
prefix = self.model.base_model_prefix prefix = self.model.base_model_prefix
# MLP modifications # MLP modifications
self.tensor_parallelize(self.model) self.apply_base_model_tp_plan(self.model)
# Attention modifications (assumes 1 attention op per hidden layer) # Attention modifications (assumes 1 attention op per hidden layer)
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -170,13 +181,13 @@ class TransformersModel(nn.Module): ...@@ -170,13 +181,13 @@ class TransformersModel(nn.Module):
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = get_sampler() self.sampler = get_sampler()
def log_replacement(self, name: str, old_module: nn.Module, def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
new_module: nn.Module): """
logger.debug("%s: %s -> %s", name, old_module, new_module) Apply the base model tensor parallelization plan to a module.
Currently only supports linear layers.
def tensor_parallelize(self, module: nn.Module, prefix: str = ""): """
if (self.config.base_model_tp_plan is None if (self.config.base_model_tp_plan is None
and self.vllm_config.parallel_config.tensor_parallel_size > 1): and get_tensor_model_parallel_world_size() > 1):
raise ValueError( raise ValueError(
"Trying to run tensor parallelization but the model does not " "Trying to run tensor parallelization but the model does not "
"support it yet!") "support it yet!")
...@@ -189,9 +200,9 @@ class TransformersModel(nn.Module): ...@@ -189,9 +200,9 @@ class TransformersModel(nn.Module):
new_module = replace_linear_class(child_module, style, new_module = replace_linear_class(child_module, style,
self.quant_config) self.quant_config)
setattr(module, child_name, new_module) setattr(module, child_name, new_module)
self.log_replacement(qual_name, child_module, new_module) log_replacement(qual_name, child_module, new_module)
else: else:
self.tensor_parallelize(child_module, prefix=qual_name) self.apply_base_model_tp_plan(child_module, prefix=qual_name)
def replace_vocab_embed_class(self, module: nn.Module): def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings # Use native set input embeddings
...@@ -201,8 +212,8 @@ class TransformersModel(nn.Module): ...@@ -201,8 +212,8 @@ class TransformersModel(nn.Module):
org_num_embeddings=self.config.vocab_size, org_num_embeddings=self.config.vocab_size,
quant_config=None, quant_config=None,
) )
self.log_replacement("input embedding", log_replacement("input embedding", self.model.get_input_embeddings(),
self.model.get_input_embeddings(), new_module) new_module)
self.model.set_input_embeddings(new_module) self.model.set_input_embeddings(new_module)
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment