Unverified commit 21d93c14 authored by Antoni Baum, committed by GitHub

Optimize Mixtral with expert parallelism (#2090)

parent f1c85201
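
For context on the title: with expert parallelism, each tensor-parallel rank keeps only its own slice of the MoE experts, runs those experts on the tokens routed to them, and the per-rank partial outputs are summed with an all-reduce, so the full set of Mixtral experts never has to fit on a single GPU. The sketch below is illustrative only and is not the code added by this commit; the class, sizes, and routing details are simplified assumptions.

```python
# Minimal sketch of expert parallelism for a Mixtral-style MoE layer.
# Assumptions: torch.distributed is already initialized and num_experts
# is divisible by the world size. Not the vLLM implementation.
import torch
import torch.nn as nn
import torch.distributed as dist


class ExpertParallelMoE(nn.Module):
    def __init__(self, hidden_size: int, ffn_size: int,
                 num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        rank, world = dist.get_rank(), dist.get_world_size()
        per_rank = num_experts // world
        # Each rank owns a contiguous slice of the experts.
        self.local_ids = range(rank * per_rank, (rank + 1) * per_rank)
        self.experts = nn.ModuleDict({
            str(i): nn.Sequential(
                nn.Linear(hidden_size, ffn_size, bias=False),
                nn.SiLU(),
                nn.Linear(ffn_size, hidden_size, bias=False),
            )
            for i in self.local_ids
        })
        # The router is small, so every rank keeps a full copy of it.
        self.router = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, hidden_size]
        probs = torch.softmax(self.router(x), dim=-1)
        weights, ids = torch.topk(probs, self.top_k, dim=-1)
        out = torch.zeros_like(x)
        # Compute only the locally owned experts; contributions from
        # experts owned by other ranks arrive via the all-reduce below.
        for eid in self.local_ids:
            for slot in range(self.top_k):
                mask = ids[:, slot] == eid
                if mask.any():
                    out[mask] += weights[mask, slot].unsqueeze(-1) * \
                        self.experts[str(eid)](x[mask])
        dist.all_reduce(out)
        return out
```
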
@@ -41,14 +41,6 @@ ENV NVCC_THREADS=$nvcc_threads
 RUN python3 setup.py build_ext --inplace
 
-# Build the megablocks library as wheel because it doesn't publish pre-built wheels.
-# https://github.com/stanford-futuredata/megablocks/commit/5897cd6f254b7b3edf7a708a3a3314ecb54b6f78
-RUN apt-get install -y git && \
-    git clone https://github.com/stanford-futuredata/megablocks.git && \
-    cd megablocks && \
-    git checkout 5897cd6f254b7b3edf7a708a3a3314ecb54b6f78 && \
-    MAX_JOBS=8 NVCC_THREADS=8 python3 setup.py bdist_wheel
-
 # image to run unit testing suite
 FROM dev AS test
@@ -85,12 +77,8 @@ FROM vllm-base AS vllm-openai
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install accelerate
 
+COPY vllm vllm
 COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY --from=build /workspace/megablocks/dist/*.whl /tmp/
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl && \
-    rm /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl
-COPY vllm vllm
 
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
@@ -72,10 +72,6 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
 ```bash
 pip install vllm
 ```
 
-**NOTE:** The Mixtral model additionally requires `megablocks` which can be installed with pip or [from source](https://github.com/stanford-futuredata/megablocks):
-```bash
-pip install megablocks
-```
 
 ## Getting Started
......
@@ -74,8 +74,7 @@ Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for in
 Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
 
 .. note::
-    Currently, the ROCm version of vLLM does not support Mixtral.
-    Additionally, it only supports Mistral for context lengths up to 4096.
+    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 
 .. tip::
     The easiest way to check if your model is supported is to run the program below:
......
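
The program the tip refers to is collapsed in this view. A minimal sketch of such a check, assuming the `vllm.LLM` entry point and using a placeholder model name, would be:

```python
from vllm import LLM

# If vLLM can construct the engine and generate, the architecture is supported.
llm = LLM(model="mistralai/Mixtral-8x7B-v0.1")  # placeholder model name or path
output = llm.generate("Hello, my name is")
print(output)
```
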
@@ -120,14 +120,16 @@ class ModelConfig:
         if load_format == "auto":
             load_format = "pt"
 
-        # FIXME(woosuk): This is a temporary hack. Support safetensor weights.
+        # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures and load_format != "pt":
-            logger.info(
-                "Currently, only 'pt' format is supported for Mixtral. "
-                "Changing the format to 'pt'. This may re-download the "
-                "weights if you have downloaded the safetensor weights.")
-            load_format = "pt"
+        if "MixtralForCausalLM" in architectures:
+            if load_format == "pt":
+                raise ValueError(
+                    "Currently, the 'pt' format is not supported for Mixtral. "
+                    "Please use the 'safetensors' format instead. ")
+            elif load_format == "auto":
+                # Do not fall back to pt weights.
+                load_format = "safetensors"
 
         self.load_format = load_format
......
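
In practice this means Mixtral checkpoints are now loaded from safetensors: `load_format="auto"` resolves to `safetensors`, and explicitly requesting `pt` raises a `ValueError`. A hedged usage sketch, where the model name and parallel size are placeholders rather than values taken from this commit:

```python
from vllm import LLM

# With the change above, the default (auto) load format picks safetensors
# for Mixtral; the experts are then sharded across the tensor-parallel ranks.
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder model name
    tensor_parallel_size=2,                        # placeholder parallel size
)
print(llm.generate("The capital of France is"))
```
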
@@ -39,13 +39,15 @@ _MODELS = {
 }
 
 # Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+_ROCM_UNSUPPORTED_MODELS = []
 
 # Models partially supported by ROCm.
 # Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "MistralForCausalLM":
     "Sliding window attention is not yet supported in ROCm's flash attention",
+    "MixtralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
 }
......
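
A rough sketch of how a registry like this is typically consulted when a model architecture is resolved on ROCm; the function name and logging below are illustrative assumptions, not the code in this commit:

```python
import logging

logger = logging.getLogger(__name__)


def check_rocm_support(architecture: str) -> None:
    # Illustrative only: reject unsupported architectures and warn about
    # partially supported ones using the two tables above.
    if architecture in _ROCM_UNSUPPORTED_MODELS:
        raise ValueError(
            f"Model architecture {architecture} is not supported on ROCm.")
    if architecture in _ROCM_PARTIALLY_SUPPORTED_MODELS:
        logger.warning("%s is partially supported on ROCm: %s",
                       architecture,
                       _ROCM_PARTIALLY_SUPPORTED_MODELS[architecture])
```
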
This diff is collapsed.