Unverified Commit 21d93c14 authored by Antoni Baum, committed by GitHub

Optimize Mixtral with expert parallelism (#2090)

parent f1c85201
@@ -41,14 +41,6 @@ ENV NVCC_THREADS=$nvcc_threads
 RUN python3 setup.py build_ext --inplace
-# Build the megablocks library as wheel because it doesn't publish pre-built wheels.
-# https://github.com/stanford-futuredata/megablocks/commit/5897cd6f254b7b3edf7a708a3a3314ecb54b6f78
-RUN apt-get install -y git && \
-    git clone https://github.com/stanford-futuredata/megablocks.git && \
-    cd megablocks && \
-    git checkout 5897cd6f254b7b3edf7a708a3a3314ecb54b6f78 && \
-    MAX_JOBS=8 NVCC_THREADS=8 python3 setup.py bdist_wheel
 # image to run unit testing suite
 FROM dev AS test
@@ -85,12 +77,8 @@ FROM vllm-base AS vllm-openai
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install accelerate
-COPY vllm vllm
 COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY --from=build /workspace/megablocks/dist/*.whl /tmp/
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl && \
-    rm /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl
+COPY vllm vllm
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
@@ -72,10 +72,6 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
 ```bash
 pip install vllm
 ```
-**NOTE:** The Mixtral model additionally requires `megablocks` which can be installed with pip or [from source](https://github.com/stanford-futuredata/megablocks):
-```bash
-pip install megablocks
-```
 ## Getting Started
......
@@ -74,8 +74,7 @@ Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for in
 Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
 .. note::
-    Currently, the ROCm version of vLLM does not support Mixtral.
-    Additionally, it only supports Mistral for context lengths up to 4096.
+    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 .. tip::
     The easiest way to check if your model is supported is to run the program below:
......
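The check program itself is collapsed in this diff view. As a rough, illustrative sketch only (assuming the standard `vllm.LLM` entry point; the model name below is a placeholder, not taken from the docs), such a check could look like:

```python
# Illustrative sketch: if constructing the engine and generating once succeeds
# without an "unsupported architecture" error, the model is supported by vLLM.
from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # placeholder; use your model's HF name or local path
output = llm.generate("Hello, my name is")
print(output)
```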
@@ -120,14 +120,16 @@ class ModelConfig:
         if load_format == "auto":
             load_format = "pt"
-        # FIXME(woosuk): This is a temporary hack. Support safetensor weights.
+        # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures and load_format != "pt":
-            logger.info(
-                "Currently, only 'pt' format is supported for Mixtral. "
-                "Changing the format to 'pt'. This may re-download the "
-                "weights if you have downloaded the safetensor weights.")
-            load_format = "pt"
+        if "MixtralForCausalLM" in architectures:
+            if load_format == "pt":
+                raise ValueError(
+                    "Currently, the 'pt' format is not supported for Mixtral. "
+                    "Please use the 'safetensors' format instead. ")
+            elif load_format == "auto":
+                # Do not fall back to pt weights.
+                load_format = "safetensors"
         self.load_format = load_format
......
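In effect, `load_format="auto"` now resolves to safetensors weights for Mixtral, and explicitly requesting `"pt"` raises an error instead of silently switching formats. A minimal usage sketch under that assumption (the model id and parallelism degree below are illustrative, not taken from this diff):

```python
from vllm import LLM, SamplingParams

# With the default load_format="auto", Mixtral now loads safetensors weights;
# passing load_format="pt" would raise a ValueError per the check above.
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # assumed HF repo id
    tensor_parallel_size=8,                        # requires 8 GPUs; experts are sharded across ranks
    load_format="auto",
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)
```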
@@ -39,13 +39,15 @@ _MODELS = {
 }
 # Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+_ROCM_UNSUPPORTED_MODELS = []
 # Models partially supported by ROCm.
 # Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "MistralForCausalLM":
     "Sliding window attention is not yet supported in ROCm's flash attention",
+    "MixtralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
 }
......
This diff is collapsed.
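The collapsed diff is presumably the Mixtral model implementation itself. Purely for orientation, here is a minimal, self-contained sketch of the general expert-parallelism idea named in the commit title: each tensor-parallel rank materializes only a slice of the experts, runs the tokens routed to those experts, and an all-reduce sums the per-rank partial outputs. This is not the code from the collapsed diff; every name in it (`ExpertParallelMoE`, `local_expert_ids`, and so on) is an illustrative assumption.

```python
# Sketch of expert parallelism for a Mixtral-style MoE layer (illustrative only).
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class ExpertParallelMoE(nn.Module):
    """Each rank owns num_experts // world_size experts; the router is replicated."""

    def __init__(self, hidden_size: int, ffn_dim: int, num_experts: int,
                 top_k: int, rank: int = 0, world_size: int = 1):
        super().__init__()
        assert num_experts % world_size == 0
        self.top_k = top_k
        per_rank = num_experts // world_size
        self.local_expert_ids = list(range(rank * per_rank, (rank + 1) * per_rank))
        # The gating network is small, so it is replicated on every rank.
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # Only this rank's slice of the experts is materialized in memory.
        self.experts = nn.ModuleDict({
            str(e): nn.Sequential(
                nn.Linear(hidden_size, ffn_dim, bias=False),
                nn.SiLU(),
                nn.Linear(ffn_dim, hidden_size, bias=False),
            )
            for e in self.local_expert_ids
        })

    @torch.no_grad()
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size]
        routing_weights = F.softmax(self.gate(hidden_states), dim=-1)
        topk_weights, topk_ids = routing_weights.topk(self.top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(hidden_states)
        for expert_id in self.local_expert_ids:
            # Which tokens (and which of their top-k slots) picked this expert.
            token_idx, slot_idx = torch.where(topk_ids == expert_id)
            if token_idx.numel() == 0:
                continue
            expert_out = self.experts[str(expert_id)](hidden_states[token_idx])
            scale = topk_weights[token_idx, slot_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, expert_out * scale)

        # Each rank holds only its experts' partial sum; combine across ranks.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(output)
        return output


# Single-process usage (world_size=1 keeps all experts local, no all-reduce):
moe = ExpertParallelMoE(hidden_size=64, ffn_dim=256, num_experts=8, top_k=2)
print(moe(torch.randn(5, 64)).shape)  # torch.Size([5, 64])
```

Compared with replicating every expert on every rank, this keeps per-GPU expert memory at `num_experts / world_size` sets of weights, at the cost of one all-reduce per MoE layer.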