Commit 66b809cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.2' into v0.7.2-dev

parents 37b63c24 0408efc6
# SPDX-License-Identifier: Apache-2.0
import os import os
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
...@@ -77,6 +79,9 @@ class RocmPlatform(Platform): ...@@ -77,6 +79,9 @@ class RocmPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1,
use_mla) -> str: use_mla) -> str:
if use_mla:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -66,9 +68,14 @@ class XPUPlatform(Platform): ...@@ -66,9 +68,14 @@ class XPUPlatform(Platform):
# check and update model config # check and update model config
model_config = vllm_config.model_config model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16: if model_config.dtype == torch.bfloat16:
logger.warning( bf16_supported = cls.device_support_bf16()
"bfloat16 is not fully supported on XPU, casting to float16.") if not bf16_supported:
model_config.dtype = torch.float16 logger.warning(
"bfloat16 is only supported on Intel Data Center GPU, "
"Intel Arc GPU is not supported yet. Your device is %s,"
"which is not supported. will fallback to float16",
cls.get_device_name())
model_config.dtype = torch.float16
if not model_config.enforce_eager: if not model_config.enforce_eager:
logger.warning( logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager " "CUDA graph is not supported on XPU, fallback to the eager "
...@@ -116,3 +123,15 @@ class XPUPlatform(Platform): ...@@ -116,3 +123,15 @@ class XPUPlatform(Platform):
) -> float: ) -> float:
torch.xpu.reset_peak_memory_stats(device) torch.xpu.reset_peak_memory_stats(device)
return torch.xpu.max_memory_allocated(device) return torch.xpu.max_memory_allocated(device)
@classmethod
def device_support_bf16(cls) -> bool:
device_name = cls.get_device_name().lower()
if device_name.count("arc") > 0:
return False
elif device_name.count("data center gpu") > 0:
return True
else:
logger.warning("Unknown device name %s, always use float16",
device_name)
return False
# SPDX-License-Identifier: Apache-2.0
import logging import logging
import os import os
from typing import Callable, Dict from typing import Callable, Dict
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional from typing import Any, Optional
import msgspec import msgspec
......
# SPDX-License-Identifier: Apache-2.0
from .layerwise_profile import layerwise_profile from .layerwise_profile import layerwise_profile
__all__ = [ __all__ = [
......
# SPDX-License-Identifier: Apache-2.0
import copy import copy
from collections import defaultdict from collections import defaultdict
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
......
# SPDX-License-Identifier: Apache-2.0
import dataclasses import dataclasses
from typing import Callable, Dict, List, Type, Union from typing import Callable, Dict, List, Type, Union
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
......
# SPDX-License-Identifier: Apache-2.0
import logging import logging
import math import math
from typing import Any, Callable, Dict, List, Optional, Type from typing import Any, Callable, Dict, List, Optional, Type
......
# SPDX-License-Identifier: Apache-2.0
import msgspec import msgspec
from vllm.adapter_commons.request import AdapterRequest from vllm.adapter_commons.request import AdapterRequest
......
# SPDX-License-Identifier: Apache-2.0
# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420 # code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420
import os import os
......
# SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import Any, Optional, Set, Type from typing import Any, Optional, Set, Type
......
# SPDX-License-Identifier: Apache-2.0
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
......
# SPDX-License-Identifier: Apache-2.0
import functools import functools
import struct import struct
from dataclasses import dataclass from dataclasses import dataclass
......
# SPDX-License-Identifier: Apache-2.0
# The CLI entrypoint to vLLM. # The CLI entrypoint to vLLM.
import argparse import argparse
import os import os
......
# SPDX-License-Identifier: Apache-2.0
"""Sequence and its related classes.""" """Sequence and its related classes."""
import copy import copy
import enum import enum
......
# SPDX-License-Identifier: Apache-2.0
from array import array from array import array
from itertools import chain, count from itertools import chain, count
from typing import Iterator, List, Optional, Tuple from typing import Iterator, List, Optional, Tuple
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from typing import List, Optional
import torch import torch
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Set, Union from typing import List, Optional, Set, Union
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment