Unverified Commit ba0bfd40 authored by Zhuohan Li, committed by GitHub

TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

parent 84e4e37d
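The change is mechanical across every model file below: layer imports move from vllm.model_executor.parallel_utils.tensor_parallel to vllm.model_executor.parallel_utils.layers, and the dead perform_initialization / use_cpu_initialization keywords are dropped from the layer constructors (the layers never initialize weights at construction; the weight loader fills them in afterwards). A minimal before/after sketch of one call site, with hypothetical sizes, runnable only inside an initialized tensor-parallel group:

# Before: old module path; callers had to opt out of initialization.
from vllm.model_executor.parallel_utils.tensor_parallel import ColumnParallelLinear
qkv_proj = ColumnParallelLinear(4096, 3 * 4096, bias=True,
                                gather_output=False,
                                perform_initialization=False)

# After: same layer, new module path, no dead keyword.
from vllm.model_executor.parallel_utils.layers import ColumnParallelLinear
qkv_proj = ColumnParallelLinear(4096, 3 * 4096, bias=True,
                                gather_output=False)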
......@@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -85,14 +86,12 @@ class BloomAttention(nn.Module):
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
......@@ -129,15 +128,17 @@ class BloomMLP(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
gather_output=False,
perform_initialization=False)
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4 * hidden_size,
gather_output=False,
)
self.act = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
input_is_parallel=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.dense_h_to_4h(x)
......@@ -208,7 +209,9 @@ class BloomModel(nn.Module):
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
config.vocab_size,
self.embed_dim,
)
self.word_embeddings_layernorm = nn.LayerNorm(
self.embed_dim, eps=config.layer_norm_epsilon)
......
......@@ -36,9 +36,11 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
reduce_from_tensor_model_parallel_region)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig
......@@ -109,7 +111,6 @@ class FalconAttention(nn.Module):
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
elif self.multi_query:
......@@ -120,7 +121,6 @@ class FalconAttention(nn.Module):
self.total_num_heads * self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
self.key_value = FalconLinear(self.hidden_size,
......@@ -135,7 +135,6 @@ class FalconAttention(nn.Module):
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
......@@ -151,7 +150,6 @@ class FalconAttention(nn.Module):
self.hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
......@@ -231,7 +229,6 @@ class FalconMLP(nn.Module):
4 * hidden_size,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True)
self.act = nn.GELU()
self.reduce_row_parallel_results = not (config.new_decoder_architecture
......@@ -241,7 +238,6 @@ class FalconMLP(nn.Module):
hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
......@@ -325,7 +321,7 @@ class FalconDecoderLayer(nn.Module):
# only one all-reduce operator to reduce the results from
# both MLP and Attention layers.
mlp_output += attention_output
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
mlp_output = tensor_model_parallel_all_reduce(mlp_output)
if attention_bias is not None:
mlp_output += attention_bias
if mlp_bias is not None:
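The fused reduction in the hunk above works because all-reduce distributes over addition: summing the partial MLP and attention outputs locally and reducing once gives the same result as reducing each branch separately, so Falcon saves one collective per layer. A single-process sketch of the equivalence, simulating two tensor-parallel ranks with plain tensors:

import torch

# Per-rank partial outputs of the two parallel branches.
mlp = [torch.randn(4) for _ in range(2)]
attn = [torch.randn(4) for _ in range(2)]

# Two collectives: all-reduce each branch (modeled here as a sum over
# ranks), then add the reduced results.
two_reduces = (mlp[0] + mlp[1]) + (attn[0] + attn[1])
# One collective: add the local partials first, then all-reduce once.
one_reduce = (mlp[0] + attn[0]) + (mlp[1] + attn[1])
assert torch.allclose(two_reduces, one_reduce)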
......@@ -347,7 +343,9 @@ class FalconModel(nn.Module):
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
config.vocab_size,
self.embed_dim,
)
# Transformer blocks
self.h = nn.ModuleList([
......@@ -389,11 +387,12 @@ class FalconForCausalLM(nn.Module):
super().__init__()
self.config = config
self.transformer = FalconModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(
......
......@@ -36,8 +36,9 @@ from vllm.model_executor.weight_utils import (
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -56,16 +57,18 @@ class GPT2Attention(nn.Module):
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.c_attn = ColumnParallelLinear(
self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
......@@ -95,16 +98,18 @@ class GPT2MLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
gather_output=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
......@@ -37,8 +37,9 @@ from vllm.model_executor.weight_utils import (
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -62,29 +63,31 @@ class GPTBigCodeAttention(nn.Module):
if self.multi_query:
self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_attn_q = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else:
self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.c_attn = ColumnParallelLinear(self.hidden_size,
self.hidden_size +
2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.c_attn = ColumnParallelLinear(
self.hidden_size,
self.hidden_size + 2 * self.kv_dim,
bias=True,
gather_output=False,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale,
......@@ -124,16 +127,18 @@ class GPTBigMLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
gather_output=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
......@@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -49,16 +50,18 @@ class GPTJAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
self.qkv_proj = ColumnParallelLinear(
config.hidden_size,
3 * config.hidden_size,
bias=False,
gather_output=False,
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
input_is_parallel=True,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
......@@ -102,14 +105,16 @@ class GPTJMLP(nn.Module):
def __init__(self, intermediate_size: int, config: GPTJConfig):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(hidden_size,
intermediate_size,
gather_output=False,
perform_initialization=False)
self.fc_out = RowParallelLinear(intermediate_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.fc_in = ColumnParallelLinear(
hidden_size,
intermediate_size,
gather_output=False,
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
input_is_parallel=True,
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -159,9 +164,10 @@ class GPTJModel(nn.Module):
super().__init__()
self.config = config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(config.vocab_size,
self.embed_dim,
perform_initialization=False)
self.wte = VocabParallelEmbedding(
config.vocab_size,
self.embed_dim,
)
self.h = nn.ModuleList(
[GPTJBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
......@@ -199,10 +205,11 @@ class GPTJForCausalLM(nn.Module):
self.config = config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config)
self.lm_head = ColumnParallelLinear(config.n_embd,
config.vocab_size,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.n_embd,
config.vocab_size,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(
......
......@@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -59,11 +60,12 @@ class GPTNeoXAttention(nn.Module):
config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
input_is_parallel=True,
)
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
......@@ -100,14 +102,16 @@ class GPTNeoXMLP(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
gather_output=False,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
gather_output=False,
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
input_is_parallel=True,
)
self.act = get_act_fn(config.hidden_act)
def forward(self, hidden_states):
......@@ -169,9 +173,10 @@ class GPTNeoXModel(nn.Module):
super().__init__()
self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.embed_in = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
......@@ -209,11 +214,12 @@ class GPTNeoXForCausalLM(nn.Module):
super().__init__()
self.config = config
self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.embed_out = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(
......
......@@ -12,8 +12,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
......@@ -31,16 +32,18 @@ class InternLMMLP(nn.Module):
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -80,14 +83,12 @@ class InternLMAttention(nn.Module):
3 * self.total_num_heads * self.head_dim,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
......@@ -176,7 +177,9 @@ class InternLMModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config)
for _ in range(config.num_hidden_layers)
......@@ -216,11 +219,12 @@ class InternLMForCausalLM(nn.Module):
self.config = config
self.model = InternLMModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(
......
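internlm.py above (and llama.py, mistral.py, and qwen.py below) rounds the vocabulary up to a multiple of 64 before building the sharded embedding and LM head, so the padded size divides evenly across tensor-parallel ranks. The arithmetic, worked through as a small sketch:

def pad_vocab_size(vocab_size: int, multiple: int = 64) -> int:
    # Equivalent to ((vocab_size + 63) // 64) * 64 in the code above.
    return ((vocab_size + multiple - 1) // multiple) * multiple

assert pad_vocab_size(32000) == 32000   # already a multiple of 64
assert pad_vocab_size(32001) == 32064   # rounded up; extra rows are padding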
......@@ -39,8 +39,7 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
......@@ -64,13 +63,11 @@ class LlamaMLP(nn.Module):
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config)
self.down_proj = ParallelLinear.row(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
......@@ -127,7 +124,6 @@ class LlamaAttention(nn.Module):
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config,
)
self.o_proj = ParallelLinear.row(
......@@ -135,7 +131,6 @@ class LlamaAttention(nn.Module):
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config,
)
self.attn = PagedAttentionWithRoPE(
......@@ -241,7 +236,9 @@ class LlamaModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
......@@ -291,7 +288,6 @@ class LlamaForCausalLM(nn.Module):
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size)
......
......@@ -38,8 +38,7 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
......@@ -64,13 +63,11 @@ class MistralMLP(nn.Module):
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config)
self.down_proj = ParallelLinear.row(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
......@@ -116,7 +113,6 @@ class MistralAttention(nn.Module):
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config,
)
self.o_proj = ParallelLinear.row(
......@@ -124,7 +120,6 @@ class MistralAttention(nn.Module):
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
......@@ -225,7 +220,9 @@ class MistralModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
MistralDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
......@@ -275,7 +272,6 @@ class MistralForCausalLM(nn.Module):
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size)
......
......@@ -15,8 +15,9 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
......@@ -53,7 +54,6 @@ class MPTAttention(nn.Module):
3 * self.d_model,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
......@@ -63,7 +63,6 @@ class MPTAttention(nn.Module):
self.d_model,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False,
)
tp_world_size = get_tensor_model_parallel_world_size()
......@@ -113,17 +112,19 @@ class MPTMLP(nn.Module):
hidden_size = config.d_model
expansion_ratio = config.expansion_ratio
intermediate_size = expansion_ratio * hidden_size
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False)
self.up_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.up_proj(x)
......@@ -172,9 +173,10 @@ class MPTModel(nn.Module):
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
self.wte = VocabParallelEmbedding(config.vocab_size,
config.d_model,
perform_initialization=False)
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList(
[MPTBlock(config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
......
......@@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -73,16 +74,18 @@ class OPTAttention(nn.Module):
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.qkv_proj = ColumnParallelLinear(
embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
......@@ -120,16 +123,18 @@ class OPTDecoderLayer(nn.Module):
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
perform_initialization=False)
self.fc1 = ColumnParallelLinear(
self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
......@@ -182,7 +187,7 @@ class OPTDecoder(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)
......
......@@ -28,7 +28,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.tensor_parallel import (
from vllm.model_executor.parallel_utils.layers import (
VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear,
......@@ -53,14 +53,12 @@ class QWenMLP(nn.Module):
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
......@@ -98,14 +96,12 @@ class QWenAttention(nn.Module):
3 * hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(
......@@ -190,9 +186,10 @@ class QWenModel(nn.Module):
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size,
config.hidden_size,
perform_initialization=False)
self.wte = VocabParallelEmbedding(
vocab_size,
config.hidden_size,
)
self.h = nn.ModuleList(
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -235,7 +232,6 @@ class QWenLMHeadModel(nn.Module):
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.sampler = Sampler(config.vocab_size)
......
import vllm.model_executor.parallel_utils.parallel_state
import vllm.model_executor.parallel_utils.tensor_parallel
__all__ = [
"parallel_state",
"tensor_parallel",
]
import torch
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
def tensor_model_parallel_all_reduce(input_):
"""All-reduce the input tensor across model parallel group.
Note: This operation is applied in-place on the input tensor.
"""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_
def tensor_model_parallel_all_gather(input_, dim=-1):
"""All-gather the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=get_tensor_model_parallel_group())
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size * input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
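A usage sketch for the two new collectives above (hypothetical shapes; this only runs on CUDA inside an initialized tensor-model-parallel group):

import torch
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)

x = torch.randn(8, 1024, device="cuda")   # per-rank partial sum
x = tensor_model_parallel_all_reduce(x)   # in-place sum over TP ranks

y = torch.randn(8, 256, device="cuda")    # per-rank shard of the last dim
y = tensor_model_parallel_all_gather(y)   # shape becomes (8, 256 * tp_size)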
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
......@@ -8,60 +9,22 @@ from typing import Optional
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .mappings import (
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from .utils import (
from vllm.model_executor.parallel_utils.utils import (
divide,
VocabUtility,
split_tensor_along_last_dim,
)
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and
param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, 'tensor_model_parallel', is_parallel)
setattr(tensor, 'partition_dim', dim)
setattr(tensor, 'partition_stride', stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute,
getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
......@@ -71,22 +34,14 @@ class VocabParallelEmbedding(torch.nn.Module):
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
Keyword Arguments:
init_method: method to initialize weights.
params_dtype
use_cpu_initialization
perform_initialization
params_dtype: type of the parameters.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, *,
init_method=init.xavier_normal_,
params_dtype: torch.dtype=None,
use_cpu_initialization: bool=False,
perform_initialization: bool=False):
super(VocabParallelEmbedding, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
......@@ -94,46 +49,39 @@ class VocabParallelEmbedding(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Set the defaults for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
self.tp_size = get_tensor_model_parallel_world_size()
# TODO: Handle vocab padding here.
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = \
self.vocab_start_index, self.vocab_end_index = (
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=params_dtype))
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
if self.tp_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
output = tensor_model_parallel_all_reduce(output_parallel)
return output
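The masking logic in the forward above can be traced with a toy example: each rank looks up only the tokens that fall in its vocabulary slice, zeroes the rest, and the all-reduce sums the shards back into a full lookup. A single-process sketch simulating two ranks:

import torch

vocab, dim, tp = 8, 2, 2
full_weight = torch.randn(vocab, dim)
tokens = torch.tensor([1, 5, 7])

partials = []
for rank in range(tp):
    start, end = rank * vocab // tp, (rank + 1) * vocab // tp
    shard = full_weight[start:end]
    mask = (tokens < start) | (tokens >= end)
    local = tokens.clone() - start
    local[mask] = 0                 # any in-range index works for masked rows
    out = shard[local]
    out[mask] = 0.0                 # zero rows owned by other ranks
    partials.append(out)

# The all-reduce (an explicit sum here) recovers the full lookup.
assert torch.allclose(partials[0] + partials[1], full_weight[tokens])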
......@@ -152,40 +100,32 @@ class ColumnParallelLinear(torch.nn.Module):
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
params_dtype:
use_cpu_initialization:
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configuration.
"""
def __init__(self, input_size, output_size, *,
bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
params_dtype=None,
use_cpu_initialization=False,
perform_initialization=False,
quant_config=None,
):
super(ColumnParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
self.world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.world_size)
self.tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.tp_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
......@@ -198,21 +138,19 @@ class ColumnParallelLinear(torch.nn.Module):
self.create_weights(params_dtype)
if bias:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
else:
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=dtype))
self.weight = Parameter(
torch.empty(self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=dtype))
def apply_weights(
self,
......@@ -225,7 +163,7 @@ class ColumnParallelLinear(torch.nn.Module):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
......@@ -238,7 +176,7 @@ class ColumnParallelLinear(torch.nn.Module):
output_parallel = self.apply_weights(input_parallel, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
......@@ -266,36 +204,25 @@ class RowParallelLinear(torch.nn.Module):
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
params_dtype:
use_cpu_initialization:
perform_initialization:
reduce_results:
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configuration.
"""
def __init__(self, input_size, output_size, *,
bias=True, input_is_parallel=False,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
params_dtype=None,
use_cpu_initialization=False,
perform_initialization=False,
reduce_results=True,
quant_config=None,
):
super(RowParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
......@@ -305,21 +232,22 @@ class RowParallelLinear(torch.nn.Module):
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the last dimension.
self.world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.world_size)
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
raise ValueError('When not reducing the results, adding bias to the '
                 'results can lead to incorrect results')
if bias:
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype))
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
......@@ -328,9 +256,11 @@ class RowParallelLinear(torch.nn.Module):
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=dtype))
self.weight = Parameter(
torch.empty(self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight)
......@@ -339,7 +269,9 @@ class RowParallelLinear(torch.nn.Module):
"""Forward of RowParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
......@@ -349,11 +281,16 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# TODO: simplify code below
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.world_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
......
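Putting the two refactored layers together: every MLP in the model files above pairs a ColumnParallelLinear (gather_output=False, output stays sharded) with a RowParallelLinear (input_is_parallel=True, all-reduce inside), so only one collective runs per MLP. A hedged sketch with hypothetical sizes, runnable only inside an initialized tensor-parallel group:

import torch.nn as nn
from vllm.model_executor.parallel_utils.layers import (
    ColumnParallelLinear, RowParallelLinear)

class ShardedMLP(nn.Module):
    def __init__(self, hidden: int = 1024, intermediate: int = 4096):
        super().__init__()
        # Shards the intermediate dimension; no gather, pass shards onward.
        self.up = ColumnParallelLinear(hidden, intermediate,
                                       gather_output=False)
        # Consumes the shards; the all-reduce inside restores full outputs.
        self.down = RowParallelLinear(intermediate, hidden,
                                      input_is_parallel=True)

    def forward(self, x):
        x, _ = self.up(x)     # both layers return (output, optional bias)
        x = nn.functional.gelu(x)
        x, _ = self.down(x)
        return x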
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed,
)
from .utils import (
split_tensor_along_last_dim,
)
__all__ = [
#layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_tensor_model_parallel_attributes",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
]
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
return input_
def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous()
return output
def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size()[0]
assert dim_size % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
rank = get_tensor_model_parallel_rank()
dim_offset = rank * local_dim_size
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
return output
def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_tensor_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._all_gather_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, tensor_parallel_output_grad=True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, tensor_parallel_output_grad=True):
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
# If the computation graph after the gather operation is
# in tensor-parallel mode, output gradients need to be
# reduce-scattered, whereas if the computation is duplicated,
# output gradients need to be scattered.
if tensor_parallel_output_grad:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# -----------------
# Helper functions.
# -----------------
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
def reduce_scatter_to_sequence_parallel_region(input_):
return _ReduceScatterToSequenceParallelRegion.apply(input_)
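The autograd wrappers above are dropped by this refactor: inference never runs the backward passes, so reduce_from_tensor_model_parallel_region becomes tensor_model_parallel_all_reduce, gather_from_tensor_model_parallel_region becomes tensor_model_parallel_all_gather, and the scatter path is inlined as split_tensor_along_last_dim plus picking the local rank's chunk. A small demo of that split helper's behavior (a re-sketch for illustration, mirroring the helper in parallel_utils.utils):

import torch

def split_tensor_along_last_dim(t: torch.Tensor, num_partitions: int):
    # Equal chunks along the last dimension, as in _split_along_last_dim.
    last_dim = t.dim() - 1
    return torch.split(t, t.size(last_dim) // num_partitions, dim=last_dim)

x = torch.arange(8.0).view(2, 4)
chunks = split_tensor_along_last_dim(x, 2)
assert len(chunks) == 2 and chunks[0].shape == (2, 2)
# torch.split returns views that may not be contiguous, hence the
# .contiguous() call on the selected chunk in the code above.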
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
)
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for 4+ GPU cases.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for bookkeeping and ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _CUDA_RNG_STATE_TRACKER
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no torch.cuda.manual_seed should be called
after this function. Basically, this is a replacement for that
function.
Two sets of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model parallel groups. This is used for
example for dropout in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
tensor_model_parallel_seed)
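A usage sketch of the tracker API above (requires a CUDA device; in the pre-refactor code the module-level instance is reached through get_cuda_rng_tracker()):

import torch

tracker = CudaRNGStatesTracker()          # class defined above
tracker.add("example-state", seed=1234)   # snapshot a stream seeded at 1234
with tracker.fork("example-state"):
    a = torch.randn(2, device="cuda")     # drawn from the tracked stream
# On exit the tracked state is updated and the original stream restored,
# so a second fork continues where the first left off.
with tracker.fork("example-state"):
    b = torch.randn(2, device="cuda")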
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Sequence
import torch
from typing import List, Sequence
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
numerator, denominator)
def divide(numerator, denominator):
......@@ -56,15 +57,14 @@ class VocabUtility:
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
per_partition_vocab_size: int, rank: int) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
)
per_partition_vocab_size, rank)
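A worked example of the sharding helper above: a 32000-token vocabulary over 4 tensor-parallel ranks yields 8000 rows per rank.

ranges = [
    VocabUtility.vocab_range_from_global_vocab_size(32000, rank, 4)
    for rank in range(4)
]
assert ranges == [(0, 8000), (8000, 16000), (16000, 24000), (24000, 32000)]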