Unverified Commit ba0bfd40 authored by Zhuohan Li, committed by GitHub

TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

parent 84e4e37d
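
For orientation (not part of the commit), here is a minimal sketch of how a tensor-parallel MLP is built against the simplified layer API introduced by this change: the classes now come from parallel_utils.layers and the perform_initialization / init_method style arguments are gone. The layer size is made up and a tensor-parallel group is assumed to already be initialized.

# Illustrative sketch only (not code from this commit).
import torch.nn as nn

from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
                                                       RowParallelLinear)


class ToyParallelMLP(nn.Module):

    def __init__(self, hidden_size: int = 1024):
        super().__init__()
        # Column-parallel layer: keep the per-rank slice of the output
        # (gather_output=False) so the next layer can consume it directly.
        self.up_proj = ColumnParallelLinear(hidden_size,
                                            4 * hidden_size,
                                            bias=True,
                                            gather_output=False)
        # Row-parallel layer: the input is already partitioned
        # (input_is_parallel=True); the result is all-reduced internally.
        self.down_proj = RowParallelLinear(4 * hidden_size,
                                           hidden_size,
                                           bias=True,
                                           input_is_parallel=True)

    def forward(self, x):
        # Both layers return an (output, bias) tuple; bias is None unless
        # skip_bias_add=True.
        x, _ = self.up_proj(x)
        x, _ = self.down_proj(x)
        return x
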
@@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -85,14 +86,12 @@ class BloomAttention(nn.Module):
             3 * self.hidden_size,
             bias=True,
             gather_output=False,
-            perform_initialization=False,
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             input_is_parallel=True,
-            perform_initialization=False,
         )

         # Create the alibi slopes and slice them.

@@ -129,15 +128,17 @@ class BloomMLP(nn.Module):

     def __init__(self, config: BloomConfig):
         super().__init__()
         hidden_size = config.hidden_size
-        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
-                                                  4 * hidden_size,
-                                                  gather_output=False,
-                                                  perform_initialization=False)
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            4 * hidden_size,
+            gather_output=False,
+        )
         self.act = get_act_fn("gelu")
-        self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
-                                               hidden_size,
-                                               input_is_parallel=True,
-                                               perform_initialization=False)
+        self.dense_4h_to_h = RowParallelLinear(
+            4 * hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.dense_h_to_4h(x)

@@ -208,7 +209,9 @@ class BloomModel(nn.Module):
         # Embedding + LN Embedding
         self.word_embeddings = VocabParallelEmbedding(
-            config.vocab_size, self.embed_dim, perform_initialization=False)
+            config.vocab_size,
+            self.embed_dim,
+        )
         self.word_embeddings_layernorm = nn.LayerNorm(
             self.embed_dim, eps=config.layer_norm_epsilon)
...
@@ -36,9 +36,11 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
-    reduce_from_tensor_model_parallel_region)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_all_reduce)
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import RWConfig

@@ -109,7 +111,6 @@ class FalconAttention(nn.Module):
                 self.head_dim,
                 bias=config.bias,
                 gather_output=False,
-                perform_initialization=False,
                 skip_bias_add=True,
             )
         elif self.multi_query:

@@ -120,7 +121,6 @@ class FalconAttention(nn.Module):
                 self.total_num_heads * self.head_dim,
                 bias=config.bias,
                 gather_output=False,
-                perform_initialization=False,
                 skip_bias_add=True,
             )
             self.key_value = FalconLinear(self.hidden_size,

@@ -135,7 +135,6 @@ class FalconAttention(nn.Module):
                 self.head_dim,
                 bias=config.bias,
                 gather_output=False,
-                perform_initialization=False,
                 skip_bias_add=True,
             )

@@ -151,7 +150,6 @@ class FalconAttention(nn.Module):
             self.hidden_size,
             bias=config.bias,
             input_is_parallel=True,
-            perform_initialization=False,
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results)

@@ -231,7 +229,6 @@ class FalconMLP(nn.Module):
             4 * hidden_size,
             bias=config.bias,
             gather_output=False,
-            perform_initialization=False,
             skip_bias_add=True)
         self.act = nn.GELU()
         self.reduce_row_parallel_results = not (config.new_decoder_architecture

@@ -241,7 +238,6 @@ class FalconMLP(nn.Module):
             hidden_size,
             bias=config.bias,
             input_is_parallel=True,
-            perform_initialization=False,
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results)

@@ -325,7 +321,7 @@ class FalconDecoderLayer(nn.Module):
             # only one all-reduce operator to reduce the results from
             # both MLP and Attention layers.
             mlp_output += attention_output
-            mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
+            mlp_output = tensor_model_parallel_all_reduce(mlp_output)

         if attention_bias is not None:
             mlp_output += attention_bias
         if mlp_bias is not None:

@@ -347,7 +343,9 @@ class FalconModel(nn.Module):
         # Embedding + LN Embedding
         self.word_embeddings = VocabParallelEmbedding(
-            config.vocab_size, self.embed_dim, perform_initialization=False)
+            config.vocab_size,
+            self.embed_dim,
+        )

         # Transformer blocks
         self.h = nn.ModuleList([

@@ -389,11 +387,12 @@ class FalconForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.transformer = FalconModel(config)
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
...
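
The decoder-layer hunk above works because, when the row-parallel attention and MLP projections are created with reduce_results=False, each branch produces only a per-rank partial sum; the two partial sums can be added locally and reduced with a single collective. A hedged, illustrative sketch of that pattern (not the actual FalconDecoderLayer code):

# Illustrative only: both inputs are per-rank partial results from
# RowParallelLinear layers constructed with reduce_results=False.
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)


def fuse_and_reduce(mlp_output, attention_output):
    # Sum the two partial results locally, then pay for one all-reduce
    # instead of two.
    mlp_output += attention_output
    return tensor_model_parallel_all_reduce(mlp_output)
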
@@ -36,8 +36,9 @@ from vllm.model_executor.weight_utils import (
     load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -56,16 +57,18 @@ class GPT2Attention(nn.Module):
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5

-        self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
-                                           bias=True,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size,
-                                        self.hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_attn = ColumnParallelLinear(
+            self.hidden_size,
+            3 * self.hidden_size,
+            bias=True,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scale=self.scale)

@@ -95,16 +98,18 @@ class GPT2MLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size,
-                                         intermediate_size,
-                                         bias=True,
-                                         gather_output=False,
-                                         perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size,
-                                        hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...
@@ -37,8 +37,9 @@ from vllm.model_executor.weight_utils import (
     load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -62,29 +63,31 @@ class GPTBigCodeAttention(nn.Module):
         if self.multi_query:
             self.num_kv_heads = 1
             self.kv_dim = self.head_dim
-            self.c_attn_q = ColumnParallelLinear(self.hidden_size,
-                                                 self.hidden_size,
-                                                 bias=True,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
+            self.c_attn_q = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=True,
+                gather_output=False,
+            )
             self.c_attn_kv = nn.Linear(self.hidden_size,
                                        2 * self.kv_dim,
                                        bias=True)
         else:
             self.num_kv_heads = self.num_heads
             self.kv_dim = self.num_kv_heads * self.head_dim
-            self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                               self.hidden_size +
-                                               2 * self.kv_dim,
-                                               bias=True,
-                                               gather_output=False,
-                                               perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size,
-                                        self.hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+            self.c_attn = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size + 2 * self.kv_dim,
+                bias=True,
+                gather_output=False,
+            )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )

         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scale=self.scale,

@@ -124,16 +127,18 @@ class GPTBigMLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size,
-                                         intermediate_size,
-                                         bias=True,
-                                         gather_output=False,
-                                         perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size,
-                                        hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...
@@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -49,16 +50,18 @@ class GPTJAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads

-        self.qkv_proj = ColumnParallelLinear(config.hidden_size,
-                                             3 * config.hidden_size,
-                                             bias=False,
-                                             gather_output=False,
-                                             perform_initialization=False)
-        self.out_proj = RowParallelLinear(config.hidden_size,
-                                          config.hidden_size,
-                                          bias=False,
-                                          input_is_parallel=True,
-                                          perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(
+            config.hidden_size,
+            3 * config.hidden_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.out_proj = RowParallelLinear(
+            config.hidden_size,
+            config.hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )

         tp_world_size = get_tensor_model_parallel_world_size()
         assert self.total_num_heads % tp_world_size == 0

@@ -102,14 +105,16 @@ class GPTJMLP(nn.Module):

     def __init__(self, intermediate_size: int, config: GPTJConfig):
         super().__init__()
         hidden_size = config.n_embd
-        self.fc_in = ColumnParallelLinear(hidden_size,
-                                          intermediate_size,
-                                          gather_output=False,
-                                          perform_initialization=False)
-        self.fc_out = RowParallelLinear(intermediate_size,
-                                        hidden_size,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.fc_in = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            gather_output=False,
+        )
+        self.fc_out = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@@ -159,9 +164,10 @@ class GPTJModel(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.n_embd
-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          self.embed_dim,
-                                          perform_initialization=False)
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            self.embed_dim,
+        )
         self.h = nn.ModuleList(
             [GPTJBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

@@ -199,10 +205,11 @@ class GPTJForCausalLM(nn.Module):
         self.config = config
         assert not config.tie_word_embeddings
         self.transformer = GPTJModel(config)
-        self.lm_head = ColumnParallelLinear(config.n_embd,
-                                            config.vocab_size,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.n_embd,
+            config.vocab_size,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
...
@@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -59,11 +60,12 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             3 * config.hidden_size,
             gather_output=False,
-            perform_initialization=False)
-        self.dense = RowParallelLinear(config.hidden_size,
-                                       config.hidden_size,
-                                       input_is_parallel=True,
-                                       perform_initialization=False)
+        )
+        self.dense = RowParallelLinear(
+            config.hidden_size,
+            config.hidden_size,
+            input_is_parallel=True,
+        )

         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)

@@ -100,14 +102,16 @@ class GPTNeoXMLP(nn.Module):

     def __init__(self, config: GPTNeoXConfig):
         super().__init__()
-        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
-                                                  config.intermediate_size,
-                                                  gather_output=False,
-                                                  perform_initialization=False)
-        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
-                                               config.hidden_size,
-                                               input_is_parallel=True,
-                                               perform_initialization=False)
+        self.dense_h_to_4h = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            gather_output=False,
+        )
+        self.dense_4h_to_h = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states):

@@ -169,9 +173,10 @@ class GPTNeoXModel(nn.Module):
         super().__init__()
         self.config = config

-        self.embed_in = VocabParallelEmbedding(config.vocab_size,
-                                               config.hidden_size,
-                                               perform_initialization=False)
+        self.embed_in = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList(
             [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size,

@@ -209,11 +214,12 @@ class GPTNeoXForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.gpt_neox = GPTNeoXModel(config)
-        self.embed_out = ColumnParallelLinear(config.hidden_size,
-                                              config.vocab_size,
-                                              bias=False,
-                                              gather_output=False,
-                                              perform_initialization=False)
+        self.embed_out = ColumnParallelLinear(
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
...
@@ -12,8 +12,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear,
+                                                       VocabParallelEmbedding)
 from vllm.model_executor.weight_utils import (
     hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
     load_tensor_parallel_weights)

@@ -31,16 +32,18 @@ class InternLMMLP(nn.Module):
         hidden_act: str,
     ):
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(
+            hidden_size,
+            2 * intermediate_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")

@@ -80,14 +83,12 @@ class InternLMAttention(nn.Module):
             3 * self.total_num_heads * self.head_dim,
             bias=True,
             gather_output=False,
-            perform_initialization=False,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=True,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         self.attn = PagedAttentionWithRoPE(
             self.num_heads,

@@ -176,7 +177,9 @@ class InternLMModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
-            vocab_size, config.hidden_size, perform_initialization=False)
+            vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             InternLMDecoderLayer(config)
             for _ in range(config.num_hidden_layers)

@@ -216,11 +219,12 @@ class InternLMForCausalLM(nn.Module):
         self.config = config
         self.model = InternLMModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            vocab_size,
+            bias=False,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
...
@@ -39,8 +39,7 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.quantized_linear import ParallelLinear
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding)
+from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
 from vllm.model_executor.quantization_utils import QuantizationConfig
 from vllm.model_executor.weight_utils import (
     convert_pyslice_to_tensor, hf_model_weights_iterator,

@@ -64,13 +63,11 @@ class LlamaMLP(nn.Module):
             2 * intermediate_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=quant_config)
         self.down_proj = ParallelLinear.row(intermediate_size,
                                             hidden_size,
                                             bias=False,
                                             input_is_parallel=True,
-                                            perform_initialization=False,
                                             quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "

@@ -127,7 +124,6 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=quant_config,
         )
         self.o_proj = ParallelLinear.row(

@@ -135,7 +131,6 @@ class LlamaAttention(nn.Module):
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
             quant_config=quant_config,
         )
         self.attn = PagedAttentionWithRoPE(

@@ -241,7 +236,9 @@ class LlamaModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
-            vocab_size, config.hidden_size, perform_initialization=False)
+            vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)

@@ -291,7 +288,6 @@ class LlamaForCausalLM(nn.Module):
             vocab_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=None)
         self.sampler = Sampler(config.vocab_size)
...
@@ -38,8 +38,7 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.quantized_linear import ParallelLinear
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding)
+from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
 from vllm.model_executor.quantization_utils import QuantizationConfig
 from vllm.model_executor.weight_utils import (
     convert_pyslice_to_tensor, hf_model_weights_iterator,

@@ -64,13 +63,11 @@ class MistralMLP(nn.Module):
             2 * intermediate_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=quant_config)
         self.down_proj = ParallelLinear.row(intermediate_size,
                                             hidden_size,
                                             bias=False,
                                             input_is_parallel=True,
-                                            perform_initialization=False,
                                             quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "

@@ -116,7 +113,6 @@ class MistralAttention(nn.Module):
             self.head_dim,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=quant_config,
         )
         self.o_proj = ParallelLinear.row(

@@ -124,7 +120,6 @@ class MistralAttention(nn.Module):
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
             quant_config=quant_config,
         )
         self.attn = PagedAttentionWithRoPE(self.num_heads,

@@ -225,7 +220,9 @@ class MistralModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
-            vocab_size, config.hidden_size, perform_initialization=False)
+            vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             MistralDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)

@@ -275,7 +272,6 @@ class MistralForCausalLM(nn.Module):
             vocab_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=None)
         self.sampler = Sampler(config.vocab_size)
...
@@ -15,8 +15,9 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs.mpt import MPTConfig

@@ -53,7 +54,6 @@ class MPTAttention(nn.Module):
             3 * self.d_model,
             bias=not config.no_bias,
             gather_output=False,
-            perform_initialization=False,
         )
         if self.qk_ln:
             self.q_ln = nn.LayerNorm(self.d_model)

@@ -63,7 +63,6 @@ class MPTAttention(nn.Module):
             self.d_model,
             bias=not config.no_bias,
             input_is_parallel=True,
-            perform_initialization=False,
         )

         tp_world_size = get_tensor_model_parallel_world_size()

@@ -113,17 +112,19 @@ class MPTMLP(nn.Module):
         hidden_size = config.d_model
         expansion_ratio = config.expansion_ratio
         intermediate_size = expansion_ratio * hidden_size
-        self.up_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size,
-                                            bias=not config.no_bias,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=not config.no_bias,
+            gather_output=False,
+        )
         self.act = get_act_fn("gelu")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=not config.no_bias,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=not config.no_bias,
+            input_is_parallel=True,
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.up_proj(x)

@@ -172,9 +173,10 @@ class MPTModel(nn.Module):
         assert config.embedding_fraction == 1.0
         assert config.norm_type == "low_precision_layernorm"

-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          config.d_model,
-                                          perform_initialization=False)
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            config.d_model,
+        )
         self.blocks = nn.ModuleList(
             [MPTBlock(config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
...
@@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -73,16 +74,18 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim**-0.5

-        self.qkv_proj = ColumnParallelLinear(embed_dim,
-                                             3 * embed_dim,
-                                             bias=bias,
-                                             gather_output=False,
-                                             perform_initialization=False)
-        self.out_proj = RowParallelLinear(embed_dim,
-                                          embed_dim,
-                                          bias=bias,
-                                          input_is_parallel=True,
-                                          perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(
+            embed_dim,
+            3 * embed_dim,
+            bias=bias,
+            gather_output=False,
+        )
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            input_is_parallel=True,
+        )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scale=self.scaling)

@@ -120,16 +123,18 @@ class OPTDecoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)
-        self.fc1 = ColumnParallelLinear(self.embed_dim,
-                                        config.ffn_dim,
-                                        bias=config.enable_bias,
-                                        gather_output=False,
-                                        perform_initialization=False)
-        self.fc2 = RowParallelLinear(config.ffn_dim,
-                                     self.embed_dim,
-                                     bias=config.enable_bias,
-                                     input_is_parallel=True,
-                                     perform_initialization=False)
+        self.fc1 = ColumnParallelLinear(
+            self.embed_dim,
+            config.ffn_dim,
+            bias=config.enable_bias,
+            gather_output=False,
+        )
+        self.fc2 = RowParallelLinear(
+            config.ffn_dim,
+            self.embed_dim,
+            bias=config.enable_bias,
+            input_is_parallel=True,
+        )
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)

@@ -182,7 +187,7 @@ class OPTDecoder(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.word_embed_proj_dim,
-            perform_initialization=False)
+        )
         # Positional embeddings are replicated (not sharded).
         self.embed_positions = OPTLearnedPositionalEmbedding(
             config.max_position_embeddings, config.hidden_size)
...
@@ -28,7 +28,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.parallel_utils.tensor_parallel import (
+from vllm.model_executor.parallel_utils.layers import (
     VocabParallelEmbedding,
     ColumnParallelLinear,
     RowParallelLinear,

@@ -53,14 +53,12 @@ class QWenMLP(nn.Module):
             2 * intermediate_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "

@@ -98,14 +96,12 @@ class QWenAttention(nn.Module):
             3 * hidden_size,
             bias=True,
             gather_output=False,
-            perform_initialization=False,
        )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         self.scaling = self.head_dim**-0.5
         self.attn = PagedAttentionWithRoPE(

@@ -190,9 +186,10 @@ class QWenModel(nn.Module):
         self.vocab_size = config.vocab_size

         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.wte = VocabParallelEmbedding(vocab_size,
-                                          config.hidden_size,
-                                          perform_initialization=False)
+        self.wte = VocabParallelEmbedding(
+            vocab_size,
+            config.hidden_size,
+        )
         self.h = nn.ModuleList(
             [QWenBlock(config) for _ in range(config.num_hidden_layers)])
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -235,7 +232,6 @@ class QWenLMHeadModel(nn.Module):
             vocab_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
         )
         self.sampler = Sampler(config.vocab_size)
...

-import vllm.model_executor.parallel_utils.parallel_state
-import vllm.model_executor.parallel_utils.tensor_parallel
-
-__all__ = [
-    "parallel_state",
-    "tensor_parallel",
-]

+import torch
+
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_world_size,
+    get_tensor_model_parallel_group,
+)
+
+
+def tensor_model_parallel_all_reduce(input_):
+    """All-reduce the input tensor across model parallel group.
+
+    Note: This operation is applied in-place on the input tensor.
+    """
+    # Bypass the function if we are using only 1 GPU.
+    if get_tensor_model_parallel_world_size() == 1:
+        return input_
+    # All-reduce.
+    torch.distributed.all_reduce(input_,
+                                 group=get_tensor_model_parallel_group())
+    return input_
+
+
+def tensor_model_parallel_all_gather(input_, dim=-1):
+    """All-gather the input tensor across model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    assert -input_.dim() <= dim < input_.dim(), (
+        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+    if dim < 0:
+        # Convert negative dim to positive.
+        dim += input_.dim()
+    input_size = input_.size()
+    # Allocate output tensor.
+    output_tensor = torch.empty((world_size, ) + input_size,
+                                dtype=input_.dtype,
+                                device=input_.device)
+    # All-gather.
+    torch.distributed.all_gather_into_tensor(
+        output_tensor, input_, group=get_tensor_model_parallel_group())
+    # Reshape
+    output_tensor = output_tensor.movedim(0, dim)
+    output_tensor = output_tensor.reshape(input_size[:dim] +
+                                          (world_size * input_size[dim], ) +
+                                          input_size[dim + 1:])
+    return output_tensor
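
As a quick shape check for the helper above, here is a hedged usage sketch (it assumes torch.distributed and the tensor-parallel group are already initialized; the function name gather_last_dim is illustrative only):

# Hypothetical usage: gathering a hidden dimension that was partitioned
# across tensor-parallel ranks.
import torch

from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_gather)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)


def gather_last_dim(partial: torch.Tensor) -> torch.Tensor:
    # partial: [batch, hidden // tp_size] on every rank.
    full = tensor_model_parallel_all_gather(partial, dim=-1)
    # full: [batch, hidden]; when tp_size == 1 the input is returned as-is.
    tp_size = get_tensor_model_parallel_world_size()
    assert full.shape[-1] == partial.shape[-1] * tp_size
    return full
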

 # Copyright 2023 The vLLM team.
-# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
+# Adapted from
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Parts of the code here are adapted from PyTorch
@@ -8,60 +9,22 @@ from typing import Optional

 import torch
 import torch.nn.functional as F
-import torch.nn.init as init
 from torch.nn.parameter import Parameter

 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from .mappings import (
-    gather_from_tensor_model_parallel_region,
-    reduce_from_tensor_model_parallel_region,
-    scatter_to_tensor_model_parallel_region,
-)
-
-from .utils import (
+from vllm.model_executor.quantization_utils import QuantizationConfig
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
+
+from vllm.model_executor.parallel_utils.utils import (
     divide,
     VocabUtility,
-    split_tensor_along_last_dim,
 )

-_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
-                                      'partition_dim': -1,
-                                      'partition_stride': 1}
-
-
-def param_is_not_tensor_parallel_duplicate(param):
-    return (hasattr(param, 'tensor_model_parallel') and
-            param.tensor_model_parallel) or (
-                get_tensor_model_parallel_rank() == 0)
-
-
-def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
-    # Make sure the attributes are not set.
-    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
-        assert not hasattr(tensor, attribute)
-    # Set the attributes.
-    setattr(tensor, 'tensor_model_parallel', is_parallel)
-    setattr(tensor, 'partition_dim', dim)
-    setattr(tensor, 'partition_stride', stride)
-
-
-def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
-    def maybe_set(attribute, value):
-        if not hasattr(tensor, attribute):
-            setattr(tensor, attribute, value)
-    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
-        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
-
-
-def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
-    def maybe_copy(attribute):
-        if hasattr(source_tensor, attribute):
-            setattr(destination_tensor, attribute,
-                    getattr(source_tensor, attribute))
-    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
-        maybe_copy(attribute)
-
-
 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.
@@ -71,22 +34,14 @@ class VocabParallelEmbedding(torch.nn.Module):

     Arguments:
         num_embeddings: vocabulary size.
         embedding_dim: size of hidden state.
-
-    Keyword Arguments:
-        init_method: method to initialize weights.
-        params_dtype
-        use_cpu_initialization
-        perform_initialization
+        params_dtype: type of the parameters.
     """

-    def __init__(self, num_embeddings: int, embedding_dim: int, *,
-                 init_method=init.xavier_normal_,
-                 params_dtype: torch.dtype=None,
-                 use_cpu_initialization: bool=False,
-                 perform_initialization: bool=False):
-        super(VocabParallelEmbedding, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 params_dtype: Optional[torch.dtype] = None):
+        super().__init__()

         # Keep the input dimensions.
         self.num_embeddings = num_embeddings
@@ -94,46 +49,39 @@ class VocabParallelEmbedding(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        # Set the defaults for compatibility.
-        self.padding_idx = None
-        self.max_norm = None
-        self.norm_type = 2.
-        self.scale_grad_by_freq = False
-        self.sparse = False
-        self._weight = None
-        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        # TODO: Handle vocab padding here.
         # Divide the weight matrix along the vocaburaly dimension.
-        self.vocab_start_index, self.vocab_end_index = \
+        self.vocab_start_index, self.vocab_end_index = (
             VocabUtility.vocab_range_from_global_vocab_size(
                 self.num_embeddings, get_tensor_model_parallel_rank(),
-                self.tensor_model_parallel_size)
-        self.num_embeddings_per_partition = self.vocab_end_index - \
-            self.vocab_start_index
-        self.weight = Parameter(torch.empty(
-            self.num_embeddings_per_partition, self.embedding_dim,
-            device=torch.cuda.current_device(), dtype=params_dtype))
+                self.tp_size))
+        self.num_embeddings_per_partition = (self.vocab_end_index -
+                                             self.vocab_start_index)
+        self.weight = Parameter(
+            torch.empty(self.num_embeddings_per_partition,
+                        self.embedding_dim,
+                        device=torch.cuda.current_device(),
+                        dtype=params_dtype))

     def forward(self, input_):
-        if self.tensor_model_parallel_size > 1:
+        if self.tp_size > 1:
             # Build the mask.
-            input_mask = (input_ < self.vocab_start_index) | \
-                         (input_ >= self.vocab_end_index)
+            input_mask = ((input_ < self.vocab_start_index) |
+                          (input_ >= self.vocab_end_index))
             # Mask the input.
             masked_input = input_.clone() - self.vocab_start_index
             masked_input[input_mask] = 0
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = F.embedding(masked_input, self.weight,
-                                      self.padding_idx, self.max_norm,
-                                      self.norm_type, self.scale_grad_by_freq,
-                                      self.sparse)
+        output_parallel = F.embedding(masked_input, self.weight)
         # Mask the output embedding.
-        if self.tensor_model_parallel_size > 1:
+        if self.tp_size > 1:
             output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
-        output = reduce_from_tensor_model_parallel_region(output_parallel)
+        output = tensor_model_parallel_all_reduce(output_parallel)
         return output
@@ -152,40 +100,32 @@ class ColumnParallelLinear(torch.nn.Module):
         gather_output: If true, call all-gather on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y_i = XA_i
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
-        skip_bias_add: This was added to enable performance optimations where bias
-                       can be fused with other elementwise operations. we skip
-                       adding bias but instead return it.
-        params_dtype:
-        use_cpu_initialization:
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. we
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
     """

-    def __init__(self, input_size, output_size, *,
-                 bias=True, gather_output=True,
-                 init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False,
-                 skip_bias_add=False,
-                 params_dtype=None,
-                 use_cpu_initialization=False,
-                 perform_initialization=False,
-                 quant_config=None,
-                 ):
-        super(ColumnParallelLinear, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        gather_output: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()

         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        self.world_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, self.world_size)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, self.tp_size)
         self.skip_bias_add = skip_bias_add
         self.quant_config = quant_config
@@ -198,21 +138,19 @@ class ColumnParallelLinear(torch.nn.Module):
         self.create_weights(params_dtype)

         if bias:
-            self.bias = Parameter(torch.empty(
-                self.output_size_per_partition,
-                device=torch.cuda.current_device(),
-                dtype=params_dtype))
-            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
-            # Always initialize bias to zero.
-            with torch.no_grad():
-                self.bias.zero_()
+            self.bias = Parameter(
+                torch.empty(self.output_size_per_partition,
+                            device=torch.cuda.current_device(),
+                            dtype=params_dtype))
         else:
             self.register_parameter('bias', None)

     def create_weights(self, dtype: torch.dtype) -> None:
-        self.weight = Parameter(torch.empty(
-            self.output_size_per_partition, self.input_size,
-            device=torch.cuda.current_device(), dtype=dtype))
+        self.weight = Parameter(
+            torch.empty(self.output_size_per_partition,
+                        self.input_size,
+                        device=torch.cuda.current_device(),
+                        dtype=dtype))

     def apply_weights(
         self,
@@ -225,7 +163,7 @@ class ColumnParallelLinear(torch.nn.Module):
         """Forward of ColumnParallelLinear

         Args:
-            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+            input_: Tensor whose last dimension is `input_size`.

         Returns:
             - output

@@ -238,7 +176,7 @@ class ColumnParallelLinear(torch.nn.Module):
         output_parallel = self.apply_weights(input_parallel, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_from_tensor_model_parallel_region(output_parallel)
+            output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
@@ -266,36 +204,25 @@ class RowParallelLinear(torch.nn.Module):
         input_is_parallel: If true, we assume that the input is already
                            split across the GPUs and we do not split
                            again.
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
-        skip_bias_add: This was added to enable performance optimization where bias
-                       can be fused with other elementwise operations. We skip
-                       adding bias but instead return it.
-        params_dtype:
-        use_cpu_initialization:
-        perform_initialization:
-        reduce_results:
+        skip_bias_add: This was added to enable performance optimization where
+                       bias can be fused with other element-wise operations.
+                       We skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
     """

-    def __init__(self, input_size, output_size, *,
-                 bias=True, input_is_parallel=False,
-                 init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False,
-                 skip_bias_add=False,
-                 params_dtype=None,
-                 use_cpu_initialization=False,
-                 perform_initialization=False,
-                 reduce_results=True,
-                 quant_config=None,
-                 ):
-        super(RowParallelLinear, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()

         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
...@@ -305,21 +232,22 @@ class RowParallelLinear(torch.nn.Module): ...@@ -305,21 +232,22 @@ class RowParallelLinear(torch.nn.Module):
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.world_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.world_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.quant_config = quant_config self.quant_config = quant_config
self.create_weights(params_dtype) self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError('When not reduce the results, adding bias to the '
"results can lead to incorrect results") 'results can lead to incorrect results')
if bias: if bias:
self.bias = Parameter(torch.empty( self.bias = Parameter(
self.output_size, device=torch.cuda.current_device(), torch.empty(self.output_size,
dtype=params_dtype)) device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
...@@ -328,9 +256,11 @@ class RowParallelLinear(torch.nn.Module): ...@@ -328,9 +256,11 @@ class RowParallelLinear(torch.nn.Module):
self.register_parameter('bias', None) self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None: def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty( self.weight = Parameter(
self.output_size, self.input_size_per_partition, torch.empty(self.output_size,
device=torch.cuda.current_device(), dtype=dtype)) self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor: def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight) return F.linear(x, self.weight)
...@@ -339,7 +269,9 @@ class RowParallelLinear(torch.nn.Module): ...@@ -339,7 +269,9 @@ class RowParallelLinear(torch.nn.Module):
"""Forward of RowParallelLinear """Forward of RowParallelLinear
Args: Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden] input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns: Returns:
- output - output
...@@ -349,11 +281,16 @@ class RowParallelLinear(torch.nn.Module): ...@@ -349,11 +281,16 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
input_parallel = scatter_to_tensor_model_parallel_region(input_) # TODO: simplify code below
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
output_parallel = self.apply_weights(input_parallel) output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.world_size > 1: if self.reduce_results and self.tp_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
else: else:
output_ = output_parallel output_ = output_parallel
......
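And the row-parallel counterpart, again as a hedged single-process sketch rather than the actual distributed path: the weight is sharded along the input dimension, each simulated rank multiplies its input slice by its shard, and summing the partial products (the role of the all-reduce when reduce_results=True) matches the full matmul.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp_size, input_size, output_size = 4, 16, 8
x = torch.randn(2, 5, input_size)
weight = torch.randn(output_size, input_size)

# Shard the weight along the input dimension and the activation along its last dim.
weight_shards = torch.chunk(weight, tp_size, dim=1)
x_shards = torch.chunk(x, tp_size, dim=-1)

# Each rank forms a partial product; the all-reduce is simply a sum over ranks here.
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, weight_shards)]
reduced = torch.stack(partials).sum(dim=0)
assert torch.allclose(reduced, F.linear(x, weight), atol=1e-4)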
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups.""" """Model and data parallel groups."""
import torch import torch
from typing import Optional
# Intra-layer model parallel group that the current rank belongs to. # Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to. # Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
# A list of global ranks for each data parallel group to ease calculation of the source # A list of global ranks for each pipeline group to ease calculation of the
# rank when broadcasting weights from src to all other data parallel ranks # source rank when broadcasting from the first or last pipeline stage.
_DATA_PARALLEL_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
def initialize_model_parallel( def initialize_model_parallel(
tensor_model_parallel_size: int = 1, tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
) -> None: ) -> None:
""" """
Initialize model data parallel groups. Initialize model parallel groups.
Arguments: Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model parallelism. tensor_model_parallel_size: number of GPUs used for tensor model
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism. parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline_model_parallel_size: number of GPUs used for pipeline model
pipeline). parallelism.
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point. Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
and 8 data-parallel groups as: 4 tensor model-parallel groups:
8 data_parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7]
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] 2 pipeline model-parallel groups:
8 tensor model-parallel groups: [g0, g2, g4, g6], [g1, g3, g5, g7]
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and with a total of 16 GPUs, rank 0 to 7 belong to the first box and
...@@ -82,64 +46,23 @@ def initialize_model_parallel( ...@@ -82,64 +46,23 @@ def initialize_model_parallel(
assert torch.distributed.is_initialized() assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size() world_size: int = torch.distributed.get_world_size()
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError( raise RuntimeError(
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " f"world_size ({world_size}) is not equal to "
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
) f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
data_parallel_size: int = world_size // (tensor_model_parallel_size * num_tensor_model_parallel_groups: int = (world_size //
pipeline_model_parallel_size) tensor_model_parallel_size)
num_pipeline_model_parallel_groups: int = (world_size //
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size pipeline_model_parallel_size)
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
num_data_parallel_groups: int = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size is not None:
if not pipeline_model_parallel_size > 2:
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
"interleaved schedule")
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
if pipeline_model_parallel_split_rank is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GLOBAL_RANKS
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
all_data_parallel_group_ranks = []
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GLOBAL_RANKS = ranks
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks]
group = torch.distributed.new_group(ranks)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups. # Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \ assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
'tensor model parallel group is already initialized' "tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups): for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size) (i + 1) * tensor_model_parallel_size)
...@@ -147,268 +70,60 @@ def initialize_model_parallel( ...@@ -147,268 +70,60 @@ def initialize_model_parallel(
if rank in ranks: if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group _TENSOR_MODEL_PARALLEL_GROUP = group
# Build the pipeline model-parallel groups and embedding groups # Build the pipeline model-parallel groups.
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \ assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
'pipeline model parallel group is already initialized' "pipeline model parallel group is already initialized")
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert _POSITION_EMBEDDING_GROUP is None, \
'position embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups): for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups) ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks) group = torch.distributed.new_group(ranks)
if rank in ranks: if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group _PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks _PIPELINE_GLOBAL_RANKS = ranks
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank is not None:
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank],
ranks[-1]]
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
position_embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank]]
else:
embedding_ranks = ranks
position_embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
group = torch.distributed.new_group(position_embedding_ranks)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
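To make the grouping concrete for the docstring's example (8 GPUs, tensor_model_parallel_size=2, pipeline_model_parallel_size=4), here is a plain-Python sketch of the same range arithmetic used in the loops above, with no torch.distributed calls:

world_size = 8
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4

num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

tensor_groups = [
    list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
    for i in range(num_tensor_model_parallel_groups)
]
pipeline_groups = [
    list(range(i, world_size, num_pipeline_model_parallel_groups))
    for i in range(num_pipeline_model_parallel_groups)
]

print(tensor_groups)    # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pipeline_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]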
def model_parallel_is_initialized(): def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized.""" """Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \ return (_TENSOR_MODEL_PARALLEL_GROUP is not None
_PIPELINE_MODEL_PARALLEL_GROUP is None or \ and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
_DATA_PARALLEL_GROUP is None:
return False
return True
def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, \
'model parallel group is not initialized'
return _MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_group(): def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to.""" """Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
'intra_layer_model parallel group is not initialized' "tensor model parallel group is not initialized")
return _TENSOR_MODEL_PARALLEL_GROUP return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group(): def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to.""" """Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \ assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
'pipeline_model parallel group is not initialized' "pipeline model parallel group is not initialized")
return _PIPELINE_MODEL_PARALLEL_GROUP return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_embedding_group():
"""Get the embedding group the caller rank belongs to."""
assert _EMBEDDING_GROUP is not None, \
'embedding group is not initialized'
return _EMBEDDING_GROUP
def get_position_embedding_group():
"""Get the position embedding group the caller rank belongs to."""
assert _POSITION_EMBEDDING_GROUP is not None, \
'position embedding group is not initialized'
return _POSITION_EMBEDDING_GROUP
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_world_size(): def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group.""" """Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE return torch.distributed.get_world_size(
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: group=get_tensor_model_parallel_group())
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_world_size(): def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group.""" """Return world size for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE return torch.distributed.get_world_size(
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: group=get_pipeline_model_parallel_group())
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_rank(rank):
"""Set pipeline model parallel rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_split_rank(rank):
"""Set pipeline model parallel split rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank(): def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group.""" """Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
return _MPU_TENSOR_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_rank(): def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group.""" """Return my rank for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK return torch.distributed.get_rank(
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: group=get_pipeline_model_parallel_group())
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
if get_virtual_pipeline_model_parallel_world_size() is not None and \
get_virtual_pipeline_model_parallel_rank() != 0:
return False
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
virtual_pipeline_model_parallel_world_size = \
get_virtual_pipeline_model_parallel_world_size()
if virtual_pipeline_model_parallel_world_size is not None and \
get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1):
return False
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def is_rank_in_position_embedding_group():
"""Return true if current rank is in position embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and \
is_pipeline_stage_after_split(rank+1)
def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
def set_virtual_pipeline_model_parallel_rank(rank):
"""Set the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_virtual_pipeline_model_parallel_world_size():
"""Return the virtual pipeline-parallel world size."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def get_tensor_model_parallel_src_rank(): def get_tensor_model_parallel_src_rank():
...@@ -419,35 +134,27 @@ def get_tensor_model_parallel_src_rank(): ...@@ -419,35 +134,27 @@ def get_tensor_model_parallel_src_rank():
return (global_rank // local_world_size) * local_world_size return (global_rank // local_world_size) * local_world_size
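A quick illustration of the return expression above, under the assumption that local_world_size is the tensor-parallel world size: every rank in a tensor-parallel group maps to the group's first global rank.

def src_rank(global_rank: int, local_world_size: int) -> int:
    # Zero out the position within the tensor-parallel group.
    return (global_rank // local_world_size) * local_world_size

# With tensor_model_parallel_size = 2 and 8 ranks total:
assert [src_rank(r, 2) for r in range(8)] == [0, 0, 2, 2, 4, 4, 6, 6]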
def get_data_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the data parallel group."""
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
"Data parallel group is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_first_rank(): def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the """Return the global rank of the first process in the pipeline for the
current tensor parallel group""" current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized")
return _PIPELINE_GLOBAL_RANKS[0] return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank(): def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the """Return the global rank of the last process in the pipeline for the
current tensor parallel group""" current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized")
last_rank_local = get_pipeline_model_parallel_world_size() - 1 last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local] return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank(): def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline""" """Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank() rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size() world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
...@@ -455,45 +162,18 @@ def get_pipeline_model_parallel_next_rank(): ...@@ -455,45 +162,18 @@ def get_pipeline_model_parallel_next_rank():
def get_pipeline_model_parallel_prev_rank(): def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline""" """Return the global rank that preceeds the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank() rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size() world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def destroy_model_parallel(): def destroy_model_parallel():
"""Set the groups to none.""" """Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS
_DATA_PARALLEL_GROUP = None _PIPELINE_GLOBAL_RANKS = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed,
)
from .utils import (
split_tensor_along_last_dim,
)
__all__ = [
#layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_tensor_model_parallel_attributes",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
]
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
return input_
def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous()
return output
def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size()[0]
assert dim_size % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
rank = get_tensor_model_parallel_rank()
dim_offset = rank * local_dim_size
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
return output
def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_tensor_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._all_gather_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
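These first-dimension mappings are collective duals: a reduce-scatter followed by an all-gather over the same group is equivalent to an all-reduce. A single-process sketch of that identity, simulating the per-rank inputs with a list of tensors (shapes are arbitrary, first dimension divisible by the world size):

import torch

torch.manual_seed(0)
world_size = 4
per_rank = [torch.randn(8, 16) for _ in range(world_size)]  # one tensor per simulated rank

summed = torch.stack(per_rank).sum(dim=0)           # what an all-reduce would produce
scattered = list(summed.chunk(world_size, dim=0))   # reduce-scatter: rank i keeps chunk i
regathered = torch.cat(scattered, dim=0)            # all-gather the chunks back along dim 0

assert torch.equal(regathered, summed)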
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, tensor_parallel_output_grad=True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, tensor_parallel_output_grad=True):
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
# If the computation graph after the gather operation is
# in tensor parallel mode, output gradients need to be
# reduce-scattered, whereas if the computation is duplicated,
# output gradients only need to be split (scattered).
if tensor_parallel_output_grad:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# -----------------
# Helper functions.
# -----------------
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
def reduce_scatter_to_sequence_parallel_region(input_):
return _ReduceScatterToSequenceParallelRegion.apply(input_)
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
)
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues in 4+ GPU cases.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for bookkeeping, to ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _CUDA_RNG_STATE_TRACKER
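A hedged usage sketch for the tracker above (it only needs a CUDA device, not an initialized parallel state; the seed values are arbitrary): draws made inside fork() consume the tracked RNG state, while draws made afterwards continue from the default CUDA state.

import torch

if torch.cuda.is_available():
    tracker = CudaRNGStatesTracker()
    tracker.add('model-parallel-rng', 1234)

    torch.cuda.manual_seed(42)                # default (data-parallel) state
    with tracker.fork('model-parallel-rng'):
        a = torch.rand(3, device='cuda')      # drawn from the tracked, seed-1234 state
    b = torch.rand(3, device='cuda')          # continues from the seed-42 default state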
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no torch.cuda.manual_seed should be called
after this function. Basically, this is replacement for that
function.
Two set of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model parallel groups. This is used for
example for dropout in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
tensor_model_parallel_seed)
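Written out for a base seed of 0 and four tensor-parallel ranks, the offsets above give each rank its own tensor-model-parallel seed while the data-parallel seed stays shared (plain arithmetic, no CUDA needed):

seed = 0
offset = seed + 2718
for tp_rank in range(4):
    tensor_model_parallel_seed = offset + tp_rank
    data_parallel_seed = seed
    print(tp_rank, tensor_model_parallel_seed, data_parallel_seed)
# 0 2718 0
# 1 2719 0
# 2 2720 0
# 3 2721 0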
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Sequence
import torch import torch
from typing import List, Sequence
def ensure_divisibility(numerator, denominator): def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator.""" """Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format( assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator numerator, denominator)
)
def divide(numerator, denominator): def divide(numerator, denominator):
...@@ -56,15 +57,14 @@ class VocabUtility: ...@@ -56,15 +57,14 @@ class VocabUtility:
@staticmethod @staticmethod
def vocab_range_from_per_partition_vocab_size( def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int per_partition_vocab_size: int, rank: int) -> Sequence[int]:
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size index_l = index_f + per_partition_vocab_size
return index_f, index_l return index_f, index_l
@staticmethod @staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]: def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size) per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size( return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size per_partition_vocab_size, rank)
)