Unverified Commit ba0bfd40 authored by Zhuohan Li, committed by GitHub

TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

parent 84e4e37d
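The pattern repeated throughout the model files in this diff: the parallel linear and embedding layers now come from `vllm.model_executor.parallel_utils.layers`, and the `perform_initialization=False` argument is dropped from every call site. As orientation, here is a plain-torch sketch (toy sizes, not vLLM's classes) of the sharding scheme that `ColumnParallelLinear(gather_output=False)` followed by `RowParallelLinear(input_is_parallel=True)` implements:

```python
import torch

# Plain-torch illustration (toy sizes, not vLLM's classes) of the sharding that
# ColumnParallelLinear + RowParallelLinear implement for a two-layer projection.
torch.manual_seed(0)
hidden, inner, tp = 8, 16, 2          # tp = number of tensor-parallel ranks

x = torch.randn(4, hidden)            # activations, replicated on every rank
w1 = torch.randn(inner, hidden)       # column-parallel: split along output dim
w2 = torch.randn(hidden, inner)       # row-parallel: split along input dim

# Reference: the unsharded computation.
ref = x @ w1.t() @ w2.t()

# Sharded: rank r keeps one slice of each weight. gather_output=False means the
# intermediate stays sharded; input_is_parallel=True means w2 consumes it as-is.
partials = [
    (x @ w1_s.t()) @ w2_s.t()
    for w1_s, w2_s in zip(w1.chunk(tp, dim=0), w2.chunk(tp, dim=1))
]
out = sum(partials)                   # the all-reduce done by RowParallelLinear

assert torch.allclose(out, ref, atol=1e-4)
```

The final summation is what the real `RowParallelLinear` performs as a tensor-parallel all-reduce across ranks.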
@@ -28,4 +28,4 @@ jobs:
        pip install pylint==2.8.2
    - name: Analysing the code with pylint
      run: |
-       pylint vllm
+       pylint vllm tests
@@ -28,4 +28,4 @@ jobs:
        pip install toml==0.10.2
    - name: Running yapf
      run: |
-       yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
+       yapf --diff --recursive vllm tests
@@ -8,7 +8,7 @@
[MASTER]

# Files or directories to be skipped. They should be base names, not paths.
-ignore=docs,parallel_utils
+ignore=docs

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
...
@@ -44,7 +44,6 @@ YAPF_FLAGS=(
YAPF_EXCLUDES=(
    '--exclude' 'build/**'
-    '--exclude' 'vllm/model_executor/parallel_utils/**'
)

# Format specified files
@@ -72,7 +71,7 @@ format_changed() {
# Format all files
format_all() {
-    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
+    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests
}

## This flag formats individual files. --files *must* be the first command line
@@ -96,7 +95,7 @@ echo 'vLLM yapf: Done'
# Run Pylint
echo 'vLLM Pylint:'
-pylint vllm
+pylint vllm tests

if ! git diff --quiet &>/dev/null; then
    echo 'Reformatted files. Please review and stage the changes.'
...
@@ -14,6 +14,7 @@ app = vllm.entrypoints.api_server.app

class AsyncLLMEngineWithStats(AsyncLLMEngine):

+    # pylint: disable=redefined-outer-name
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_aborts = 0
...
@@ -24,6 +24,7 @@ def _query_server(prompt: str) -> dict:
def api_server():
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
+    # pylint: disable=consider-using-with
    uvicorn_process = subprocess.Popen([
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m"
@@ -32,6 +33,7 @@ def api_server():
    uvicorn_process.terminate()

+# pylint: disable=redefined-outer-name, unused-argument
def test_api_server(api_server):
    """
    Run the API server and test it.
@@ -47,6 +49,7 @@ def test_api_server(api_server):
    prompts = ["Hello world"] * 1
    result = None
    while not result:
+        # pylint: disable=bare-except
        try:
            for result in pool.map(_query_server, prompts):
                break
...
@@ -32,12 +32,12 @@ class MockEngine:
        self.request_id = None

    def add_request(self, **kwargs):
+        del kwargs  # Unused
        self.add_request_calls += 1
-        return

    def abort_request(self, request_id):
+        del request_id  # Unused
        self.abort_request_calls += 1
-        return

class MockAsyncLLMEngine(AsyncLLMEngine):
...
@@ -7,22 +7,22 @@ from vllm.outputs import RequestOutput

class DummyEvent:

    def __init__(self):
-        self._flag = False
+        self.flag = False

    def set(self):
-        self._flag = True
+        self.flag = True

    def clear(self):
-        self._flag = False
+        self.flag = False


def test_request_tracker():
    tracker = RequestTracker()
    tracker.new_requests_event = DummyEvent()
    stream_1 = tracker.add_request("1")
-    assert tracker.new_requests_event._flag
+    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event._flag
+    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
@@ -30,9 +30,9 @@ def test_request_tracker():
    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
-    assert tracker.new_requests_event._flag
+    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event._flag
+    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
@@ -43,7 +43,7 @@ def test_request_tracker():
    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")
-    assert not tracker.new_requests_event._flag
+    assert not tracker.new_requests_event.flag
    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
@@ -54,7 +54,7 @@ def test_request_tracker():
    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
-    assert tracker.new_requests_event._flag
+    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
@@ -62,11 +62,11 @@ def test_request_tracker():
    assert stream_4.finished
    stream_5 = tracker.add_request("5")
-    assert tracker.new_requests_event._flag
+    assert tracker.new_requests_event.flag
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], finished=True))
    new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event._flag
+    assert not tracker.new_requests_event.flag
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
...
@@ -8,6 +8,7 @@ from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_PROMPTS = [
+    # pylint: disable=line-too-long
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
...
"""Test the communication operators.
Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
from multiprocessing import Process
import pytest
import torch
from vllm.config import ParallelConfig
from vllm.engine.ray_utils import get_open_port
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce,
tensor_model_parallel_all_gather,
)
from vllm.worker.worker import _init_distributed_environment
def init_test_distributed_environment(pipeline_parallel_size: int,
tensor_parallel_size: int, rank: int,
distributed_init_port: str):
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
torch.cuda.set_device(rank)
_init_distributed_environment(parallel_config, rank,
distributed_init_method)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tensor_parallel_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank]
t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2))
total_size = 1
for s in tensor_size:
total_size *= s
for all_gather_dimension in range(num_dimensions):
all_tensors = [
torch.arange(total_size, dtype=torch.float32,
device="cuda").reshape(tensor_size) * (r + 1)
for r in range(tensor_parallel_size)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
distributed_init_port = get_open_port()
processes = []
for rank in range(tensor_parallel_size):
p = Process(target=test_target,
args=(tensor_parallel_size, rank, distributed_init_port))
p.start()
processes.append(p)
for p in processes:
p.join()
assert all(p.exitcode == 0 for p in processes)
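The expectation in `all_reduce_test_worker` above is plain arithmetic; a CPU-only sketch for `tensor_parallel_size == 2` (no process group or GPU needed, just the same computation):

```python
import torch

# CPU-only illustration of the expected value in all_reduce_test_worker,
# for tensor_parallel_size == 2.
num_elements, tp = 8, 2
all_tensors = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1) for r in range(tp)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)

# An all-reduce leaves every rank with the elementwise sum, i.e.
# arange * (1 + 2) == arange * 3 in this two-rank case.
assert torch.equal(expected, torch.arange(num_elements, dtype=torch.float32) * 3)
```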
@@ -5,6 +5,7 @@ from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import detokenize_incrementally

TRUTH = [
+    # pylint: disable=line-too-long
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
    "我很感谢你的热情"
...
@@ -29,8 +29,8 @@ def test_silu_and_mul(
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
    activation_ops.silu_and_mul(out, x)
    ref_out = ref_silu_and_mul(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -49,8 +49,8 @@ def test_gelu_new(
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
    activation_ops.gelu_new(out, x)
    ref_out = get_activation("gelu_new")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -68,8 +68,8 @@ def test_gelu_fast(
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
    activation_ops.gelu_fast(out, x)
    ref_out = get_activation("gelu_fast")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -106,14 +106,14 @@ def test_reshape_and_cache(
    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
-                      device='cuda')
+                      device="cuda")
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
@@ -132,7 +132,7 @@ def test_reshape_and_cache(
    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
-    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indicies = block_indicies.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
...
@@ -140,7 +140,7 @@ def test_rotary_embedding(
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
+    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")

    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
    out_query = query.clone()
...
+# pylint: disable=protected-access
import pytest
import random
from typing import Tuple
@@ -108,7 +109,7 @@ def test_sampler_all_random(seed: int):
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, worker = _prepare_test(batch_size)
    seq_group_metadata_list = []
    for i in range(batch_size):
...
from vllm.model_executor.layers.quantized_linear.awq import (
    AWQColumnParallelLinear, AWQRowParallelLinear)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)

_QUANTIZED_LINEAR_REGISTRY = {
    "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
...
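For context, `_QUANTIZED_LINEAR_REGISTRY` maps a quantization method name to the (column, row) parallel linear classes to instantiate. A self-contained sketch of that lookup pattern with placeholder classes (the `get_linear_classes` helper here is hypothetical and not part of this diff):

```python
from typing import Optional, Tuple, Type


# Placeholder classes standing in for the real parallel linear layers.
class ColumnParallelLinear: ...
class RowParallelLinear: ...
class AWQColumnParallelLinear(ColumnParallelLinear): ...
class AWQRowParallelLinear(RowParallelLinear): ...


_QUANTIZED_LINEAR_REGISTRY = {
    "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
}


def get_linear_classes(quant_method: Optional[str]) -> Tuple[Type, Type]:
    """Hypothetical helper: pick the parallel linear classes for a model,
    falling back to the unquantized layers when no method is configured."""
    if quant_method is None:
        return ColumnParallelLinear, RowParallelLinear
    return _QUANTIZED_LINEAR_REGISTRY[quant_method]


assert get_linear_classes("awq") == (AWQColumnParallelLinear,
                                     AWQRowParallelLinear)
```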
@@ -4,8 +4,8 @@ import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
-from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
-    ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)


class AWQColumnParallelLinear(ColumnParallelLinear):
...
@@ -5,8 +5,8 @@ import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    gather_from_tensor_model_parallel_region)
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
@@ -92,7 +92,7 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias
-    logits = gather_from_tensor_model_parallel_region(logits)
+    logits = tensor_model_parallel_all_gather(logits)
    # Remove paddings in vocab (if any).
    logits = logits[:, :vocab_size]
    return logits
...
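A single-process sketch of the shape bookkeeping in `_get_logits` above: each tensor-parallel rank scores only its shard of the (padded) vocabulary, the renamed `tensor_model_parallel_all_gather` concatenates the shards along the last dimension, and the padding is sliced off. Concatenation stands in for the gather here, and the sizes are toy values, not vLLM defaults:

```python
import torch

# Toy sizes; torch.cat stands in for tensor_model_parallel_all_gather.
tp, vocab_size, padded_vocab, hidden = 2, 5, 8, 4
hidden_states = torch.randn(3, hidden)

# Each rank holds a (padded_vocab / tp, hidden) slice of the embedding matrix
# and computes logits only for its own vocabulary shard.
shards = torch.randn(padded_vocab, hidden).chunk(tp, dim=0)
per_rank_logits = [hidden_states @ shard.t() for shard in shards]

logits = torch.cat(per_rank_logits, dim=-1)  # the all-gather step
logits = logits[:, :vocab_size]              # remove paddings in vocab
assert logits.shape == (3, vocab_size)
```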
@@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import (
    load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig
@@ -56,16 +57,18 @@ class AquilaMLP(nn.Module):
        hidden_act: str,
    ):
        super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(
+            hidden_size,
+            2 * intermediate_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
@@ -130,14 +133,12 @@ class AquilaAttention(nn.Module):
            self.head_dim,
            bias=False,
            gather_output=False,
-            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
-            perform_initialization=False,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
@@ -230,7 +231,7 @@ class AquilaModel(nn.Module):
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
-            perform_initialization=False)
+        )
        self.layers = nn.ModuleList([
            AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
@@ -270,11 +271,12 @@ class AquilaForCausalLM(nn.Module):
        self.config = config
        self.model = AquilaModel(config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            vocab_size,
+            bias=False,
+            gather_output=False,
+        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
...
@@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import (
    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
@@ -81,16 +82,18 @@ class BaiChuanMLP(nn.Module):
        hidden_act: str,
    ):
        super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(
+            hidden_size,
+            2 * intermediate_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
@@ -133,14 +136,12 @@ class BaiChuanAttention(nn.Module):
            3 * hidden_size,
            bias=False,
            gather_output=False,
-            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
-            perform_initialization=False,
        )
        # Create the alibi slopes and slice them.
        if self.postion_embedding == "ALIBI":
@@ -249,7 +250,7 @@ class BaiChuanModel(nn.Module):
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
-            perform_initialization=False)
+        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding)
            for _ in range(config.num_hidden_layers)
@@ -288,11 +289,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
        super().__init__()
        self.config = config
        self.model = BaiChuanModel(config, position_embedding)
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            gather_output=False,
+        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
...