Commit a130cf33 authored by zhuwenwen

Merge tag 'v0.3.3' into vllm-v0.3.2-dtk23.10 and add gfx

parents a2d181be 82091b86
......@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -42,7 +43,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -186,7 +186,7 @@ class BaiChuanAttention(nn.Module):
class BaiChuanDecoderLayer(nn.Module):
def __init__(self,
config: BaiChuanConfig,
config: PretrainedConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
......@@ -245,7 +245,7 @@ class BaiChuanDecoderLayer(nn.Module):
class BaiChuanModel(nn.Module):
def __init__(self,
config: BaiChuanConfig,
config: PretrainedConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
......
......@@ -41,7 +41,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
Based on the llama executor.
The main difference is that DeciLM uses Variable Grouped Query Attention.
The constant number of GQA heads in the decoder is overriden with a value
The constant number of GQA heads in the decoder is overridden with a value
per layer.
Usually, in the HuggingFace implementation, instead of
......
......@@ -20,10 +20,13 @@ import torch
from torch import nn
from transformers import GemmaConfig
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
......@@ -40,21 +43,6 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * (1 + self.weight)
class GemmaMLP(nn.Module):
def __init__(
......@@ -64,27 +52,21 @@ class GemmaMLP(nn.Module):
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=False,
linear_method=linear_method)
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=False,
linear_method=linear_method)
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
self.act_fn = nn.GELU()
self.act_fn = GeluAndMul()
def forward(self, x):
gate, _ = self.gate_proj(x)
gate = self.act_fn(gate)
up, _ = self.up_proj(x)
fuse = gate * up
outputs, _ = self.down_proj(fuse)
return outputs
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class GemmaAttention(nn.Module):
......@@ -185,10 +167,10 @@ class GemmaDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
linear_method=linear_method,
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
......@@ -196,25 +178,27 @@ class GemmaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
return hidden_states, residual
class GemmaModel(nn.Module):
......@@ -235,7 +219,7 @@ class GemmaModel(nn.Module):
GemmaDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
......@@ -246,27 +230,53 @@ class GemmaModel(nn.Module):
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Normalize the embedding by sqrt(hidden_size)
hidden_states = hidden_states * (self.config.hidden_size**0.5)
hidden_states *= self.config.hidden_size**0.5
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
residual,
)
hidden_states = self.norm(hidden_states)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class GemmaForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()
self.config = config
self.linear_method = linear_method
......@@ -304,6 +314,8 @@ class GemmaForCausalLM(nn.Module):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params = set()
......@@ -318,9 +330,10 @@ class GemmaForCausalLM(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra layer for lora models.
if "lm_head" in name:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:
loaded_weight += 1.0
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......@@ -329,5 +342,5 @@ class GemmaForCausalLM(nn.Module):
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")
......@@ -27,6 +27,7 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
......@@ -45,7 +46,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -141,7 +142,8 @@ class LlamaAttention(nn.Module):
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window)
def forward(
self,
......@@ -172,6 +174,7 @@ class LlamaDecoderLayer(nn.Module):
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
......@@ -182,6 +185,7 @@ class LlamaDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
bias=getattr(config, "bias", False),
sliding_window=sliding_window,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
......
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = None
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
with torch.inference_mode():
block_size = self.model.context_buckets[-1]
if input_metadata.is_prompt:
seq_ids = input_metadata.slot_mapping[:, 0] // block_size
else:
seq_ids = input_metadata.block_tables
logits = self.model(input_ids,
cache_ids=positions,
start_ids=seq_ids.flatten())
return logits
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
**kwargs):
from transformers_neuronx.llama.model import LlamaForSampling
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
from transformers.models.llama import LlamaForCausalLM
from transformers_neuronx.module import save_pretrained_split
hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = LlamaForSampling.from_pretrained(split_model_dir,
**kwargs)
self.model.to_neuron()
......@@ -61,7 +61,9 @@ from vllm.model_executor.weight_utils import (
hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.olmo import OLMoConfig
# This model requires the hf_olmo dependency for its configuration class.
from hf_olmo import OLMoConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mistral model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import MistralConfig
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -38,19 +20,18 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MistralMLP(nn.Module):
class OrionMLP(nn.Module):
def __init__(
self,
......@@ -80,16 +61,18 @@ class MistralMLP(nn.Module):
return x
class MistralAttention(nn.Module):
class OrionAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
......@@ -111,7 +94,7 @@ class MistralAttention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
......@@ -131,14 +114,14 @@ class MistralAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
num_kv_heads=self.num_kv_heads)
def forward(
self,
......@@ -156,35 +139,39 @@ class MistralAttention(nn.Module):
return output
class MistralDecoderLayer(nn.Module):
class OrionDecoderLayer(nn.Module):
def __init__(
self,
config: MistralConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MistralAttention(
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = OrionAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
sliding_window=config.sliding_window)
self.mlp = MistralMLP(
)
self.mlp = OrionMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
......@@ -195,12 +182,8 @@ class MistralDecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
......@@ -208,39 +191,36 @@ class MistralDecoderLayer(nn.Module):
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
hidden_states = residual + hidden_states
return hidden_states, None
class MistralModel(nn.Module):
class OrionModel(nn.Module):
def __init__(
self,
config: MistralConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
MistralDecoderLayer(config, linear_method)
OrionDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
......@@ -260,63 +240,23 @@ class MistralModel(nn.Module):
input_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
hidden_states = self.norm(hidden_states)
return hidden_states
class MistralForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
class OrionForCausalLM(nn.Module):
def __init__(
self,
config: MistralConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MistralModel(config,
linear_method,
lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
self.model = OrionModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
......@@ -356,6 +296,11 @@ class MistralForCausalLM(nn.Module):
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
......
......@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -27,7 +28,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -127,7 +127,7 @@ class QWenBlock(nn.Module):
def __init__(
self,
config: QWenConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
......@@ -179,7 +179,7 @@ class QWenModel(nn.Module):
def __init__(
self,
config: QWenConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
......@@ -222,7 +222,7 @@ class QWenLMHeadModel(nn.Module):
def __init__(
self,
config: QWenConfig,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
......
......@@ -94,7 +94,9 @@ class StablelmAttention(nn.Module):
1, self.total_num_key_value_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
rope_pct = getattr(config, "rope_pct",
getattr(config, "partial_rotary_factor", 1))
self.rotary_ndims = int(self.head_dim * rope_pct)
self.scaling = self.head_dim**-0.5
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
......@@ -114,7 +116,6 @@ class StablelmAttention(nn.Module):
self.hidden_size,
bias=False,
linear_method=linear_method)
self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
......@@ -152,10 +153,11 @@ class StablelmDecoderLayer(nn.Module):
super().__init__()
self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, linear_method)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps)
eps=norm_eps)
def forward(
self,
......@@ -199,7 +201,9 @@ class StableLMEpochModel(nn.Module):
StablelmDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
def forward(
self,
......
# coding=utf-8
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
try:
from transformers import Starcoder2Config
except ImportError:
# fallback to PretrainedConfig
# NOTE: Please install transformers from source or use transformers>=4.39.0
from transformers import PretrainedConfig as Starcoder2Config
KVCache = Tuple[torch.Tensor, torch.Tensor]
class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
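# Worked example (illustrative): with tp_size = 8, 16 total KV heads give 2 KV heads
# per rank (partitioned), while 2 total KV heads give max(1, 2 // 8) = 1 KV head per
# rank, with each KV head replicated across 4 ranks.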
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.max_position_embeddings = config.max_position_embeddings
self.use_bias = config.use_bias
self.sliding_window = config.sliding_window
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=self.use_bias,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=self.use_bias,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = PagedAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class Starcoder2MLP(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=config.use_bias,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=config.use_bias,
linear_method=linear_method,
)
self.act = get_act_fn(config.hidden_act,
intermediate_size=config.intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
linear_method=linear_method)
self.mlp = Starcoder2MLP(config, linear_method=linear_method)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Starcoder2Model(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# TODO: consider padding_idx (currently removed)
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i],
input_metadata)
hidden_states = self.norm(hidden_states)
return hidden_states
class Starcoder2ForCausalLM(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.model = Starcoder2Model(config, linear_method=linear_method)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
"""Utilities for selecting and loading models."""
from typing import Type
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig, DeviceConfig
from vllm.model_executor.models import ModelRegistry
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
"half": "f16",
"float16": "f16",
"bfloat16": "bf16",
"float": "f32",
"float32": "f32",
torch.float16: "f16",
torch.bfloat16: "bf16",
torch.float32: "f32",
}
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig
parallel_config = kwargs.get("parallel_config")
scheduler_config = kwargs.get("scheduler_config")
model_class = _get_model_architecture(model_config.hf_config)
linear_method = None
# Create a model instance.
model = model_class(model_config.hf_config, linear_method)
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
neuron_config = NeuronConfig(
continuous_batching=continuous_batching_config)
# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model,
model_config.download_dir,
model_config.load_format,
model_config.revision,
tp_degree=parallel_config.neuron_tp_degree,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=[scheduler_config.max_model_len],
n_positions=[scheduler_config.max_model_len],
batch_size=scheduler_config.max_num_seqs)
return model.eval()
......@@ -36,14 +36,14 @@ def init_custom_ar() -> None:
if world_size not in _SUPPORTED_WORLD_SIZES:
logger.warn(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To slience this warning, specify"
"%d. Supported world sizes: %s. To silence this warning, specify"
"disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
if not _can_p2p(rank, world_size):
logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability. To slience this warning, specify"
" capability. To silence this warning, specify"
"disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size)
......
......@@ -189,7 +189,7 @@ def get_pipeline_model_parallel_next_rank():
def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline"""
"""Return the global rank that precedes the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank()
......
......@@ -5,7 +5,7 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl
from vllm.utils import in_wsl, is_neuron
_SAMPLING_EPS = 1e-5
......@@ -155,7 +155,7 @@ class SamplingTensors:
dtype: torch.dtype) -> "SamplingTensors":
# Note that the performance will be very bad without
# pinned memory.
pin_memory = not in_wsl()
pin_memory = not in_wsl() and not is_neuron()
prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
......
"""Utils for model executor."""
import random
import importlib
from typing import Any, Dict, Optional
import numpy as np
import torch
from vllm.config import DeviceConfig, ModelConfig
DEVICE_TO_MODEL_LOADER_MAP = {
"cuda": "model_loader",
"neuron": "neuron_model_loader",
}
def set_random_seed(seed: int) -> None:
random.seed(seed)
......@@ -33,3 +41,12 @@ def set_weight_attrs(
assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value)
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> torch.nn.Module:
model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
imported_model_loader = importlib.import_module(
f"vllm.model_executor.{model_loader_module}")
get_model_fn = imported_model_loader.get_model
return get_model_fn(model_config, device_config, **kwargs)
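# Illustrative extension point (hypothetical names): supporting another backend only
# requires a loader module that exposes get_model() plus an entry in the map above,
# e.g. DEVICE_TO_MODEL_LOADER_MAP["cpu"] = "cpu_model_loader", after which the call
# above dispatches to vllm.model_executor.cpu_model_loader.get_model with the same
# model_config, device_config, and forwarded keyword arguments.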
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from functools import cached_property
from typing import Callable, List, Optional, Union
......@@ -237,6 +238,20 @@ class SamplingParams:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects.
LogitsProcessor objects are excluded because they may contain an
arbitrary, nontrivial amount of data.
See https://github.com/vllm-project/vllm/issues/3087
"""
logit_processor_refs = None if self.logits_processors is None else {
id(lp): lp
for lp in self.logits_processors
}
return copy.deepcopy(self, memo=logit_processor_refs)
def __repr__(self) -> str:
return (
f"SamplingParams(n={self.n}, "
......
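The clone() method above relies on copy.deepcopy's memo dict: any object whose id is already a key is treated as copied and returned as-is, so seeding the memo with {id(lp): lp} makes the clone share the LogitsProcessor objects instead of duplicating them. A minimal illustrative sketch of that behavior (names are placeholders):
import copy

class Params:
    def __init__(self, logits_processors):
        self.logits_processors = logits_processors

heavy = [object() for _ in range(3)]   # stand-ins for large LogitsProcessor objects
params = Params(heavy)
clone = copy.deepcopy(params, {id(lp): lp for lp in heavy})
assert clone is not params             # the container itself is copied
assert all(a is b for a, b in
           zip(clone.logits_processors, params.logits_processors))  # processors are shared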
......@@ -5,12 +5,11 @@ from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import *
_CONFIG_REGISTRY = {
"baichuan": BaiChuanConfig,
"chatglm": ChatGLMConfig,
"mpt": MPTConfig,
"qwen": QWenConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
"starcoder2": Starcoder2Config,
}
......@@ -18,6 +17,15 @@ def get_config(model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None) -> PretrainedConfig:
# FIXME(woosuk): This is a temporary fix for StarCoder2.
# Remove this when the model is supported by HuggingFace transformers.
if "bigcode" in model and "starcoder2" in model:
config_class = _CONFIG_REGISTRY["starcoder2"]
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
return config
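# For example, a model path containing both "bigcode" and "starcoder2" takes the
# branch above and never reaches the AutoConfig call below.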
try:
config = AutoConfig.from_pretrained(
model,
......
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.olmo import OLMoConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config
__all__ = [
"BaiChuanConfig",
"ChatGLMConfig",
"MPTConfig",
"OLMoConfig",
"QWenConfig",
"RWConfig",
"Starcoder2Config",
]
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
class BaiChuanConfig(PretrainedConfig):
model_type = "baichuan"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=64000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# coding=utf-8
# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py
"""OLMo configuration"""
from transformers import PretrainedConfig
class OLMoConfig(PretrainedConfig):
model_type = 'olmo'
attribute_map = {
'num_attention_heads': 'n_heads',
'hidden_size': 'd_model',
'num_hidden_layers': 'n_layers',
}
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
def __init__(
self,
d_model=768,
n_heads=12,
n_layers=12,
mlp_ratio=4,
mlp_hidden_size=None,
activation_type="swiglu",
block_type="sequential",
block_group_size=1,
alibi=False,
alibi_bias_max=8.0,
rope=False,
rope_full_precision=True,
multi_query_attention=False,
attention_layer_norm=False,
layer_norm_type="default",
layer_norm_with_affine=True,
attention_layer_norm_with_affine=True,
max_sequence_length=1024,
include_bias=True,
bias_for_layer_norm=None,
scale_logits=False,
vocab_size=50257,
embedding_size=50304,
weight_tying=True,
eos_token_id=50256,
pad_token_id=50256,
**kwargs,
):
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.mlp_ratio = mlp_ratio
self.mlp_hidden_size = mlp_hidden_size
self.activation_type = activation_type
self.block_type = block_type
self.block_group_size = block_group_size
self.alibi = alibi
self.alibi_bias_max = alibi_bias_max
self.rope = rope
self.rope_full_precision = rope_full_precision
self.multi_query_attention = multi_query_attention
self.attention_layer_norm = attention_layer_norm
self.layer_norm_type = layer_norm_type
self.layer_norm_with_affine = layer_norm_with_affine
self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
self.max_sequence_length = max_sequence_length
self.include_bias = include_bias
self.bias_for_layer_norm = bias_for_layer_norm
self.scale_logits = scale_logits
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.weight_tying = weight_tying
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
super().__init__(**kwargs)
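The attribute_map above uses PretrainedConfig's aliasing mechanism, so code written against the HF-standard attribute names transparently reads OLMo's native fields. A small illustrative check (values are arbitrary):
cfg = OLMoConfig(d_model=1024, n_heads=16, n_layers=24)
assert cfg.hidden_size == 1024          # aliased to d_model
assert cfg.num_attention_heads == 16    # aliased to n_heads
assert cfg.num_hidden_layers == 24      # aliased to n_layers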