"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "0fe1c647442d415542623dd33dcdecccef19fcb8"
Commit 086a9d1c authored by Azure-Tang's avatar Azure-Tang
Browse files

Add vendor control

parent c009512a
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <torch/torch.h> #include <torch/torch.h>
#include <cstdint> #include <cstdint>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
typedef hip_bfloat16 nv_bfloat16;
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
......
...@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM ...@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
custom_models = { custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
...@@ -169,7 +170,7 @@ def local_chat( ...@@ -169,7 +170,7 @@ def local_chat(
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml" "please change max_seq_len in ~/.ktransformers/config.yaml"
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8: if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate( generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size, model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
......
...@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability ...@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
import logging import logging
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache from transformers.cache_utils import Cache
from flash_attn import flash_attn_func from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
try:
from flash_attn import flash_attn_func
except:
pass
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
import os import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled: if flashinfer_enabled:
...@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): ...@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1) key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim) value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
attn_output = flash_attn_func( # for bsz = 1
query_states, attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
key_states, b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
value_states_padded, b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
softmax_scale=self.softmax_scale,
causal=True, max_input_len = q_len
context_attention_fwd(
q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
o=attn_output,
b_start_loc=b_start_loc,
b_seq_len=b_seq_len,
max_input_len=max_input_len,
is_causal=True
) )
if self.q_head_dim != self.v_head_dim: if self.q_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim] attn_output = attn_output[:, :, : self.v_head_dim]
attn_output = attn_output.reshape( attn_output = attn_output.reshape(
bsz, q_len, self.num_heads * self.v_head_dim bsz, q_len, self.num_heads * self.v_head_dim
...@@ -589,7 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): ...@@ -589,7 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if os.name == 'nt' or get_compute_capability()<8: if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
print("for Windows or GPU before ampere, use forward_windows") print("for Windows or GPU before ampere, use forward_windows")
return self.forward_windows( return self.forward_windows(
hidden_states, hidden_states,
......
...@@ -17,7 +17,10 @@ import logging ...@@ -17,7 +17,10 @@ import logging
logger = logging.getLogger("dynamic_attention") logger = logging.getLogger("dynamic_attention")
sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend") sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
from flash_attn import flash_attn_func, flash_attn_with_kvcache try:
from flash_attn import flash_attn_func, flash_attn_with_kvcache
except:
print("falsh attn not found")
import math import math
......
...@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import ( ...@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV2MoE, DeepseekV2MoE,
) )
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
...@@ -649,7 +650,7 @@ class KDeepseekV2Model(BaseInjectedModule): ...@@ -649,7 +650,7 @@ class KDeepseekV2Model(BaseInjectedModule):
if per_layer_prefill_flag: if per_layer_prefill_flag:
causal_mask = None causal_mask = None
else: else:
if os.name == 'nt' or get_compute_capability()<8: if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
print("for Windows or GPU before ampere, use forward_windows") print("for Windows or GPU before ampere, use forward_windows")
# only use mask in forward windows or can't flash attn # only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import triton import triton
import triton.language as tl import triton.language as tl
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@triton.jit @triton.jit
def tanh(x): def tanh(x):
# Tanh is just a scaled sigmoid # Tanh is just a scaled sigmoid
...@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd( ...@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
# [TODO] work around shmem limit on MI3xx # [TODO] work around shmem limit on MI3xx
# TODO: support hip # TODO: support hip
#if is_hip_ and Lk >= 576: if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
# BLOCK = 16 BLOCK = 16
if Lk == 576: if Lk == 576:
BLOCK_DMODEL = 512 BLOCK_DMODEL = 512
......
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
"""
Memory-efficient attention for prefill.
It supporst page size = 1.
"""
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl
is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
Out,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_obs,
stride_oh,
kv_group_num: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
Lk: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
mask_d = offs_d < Lk
q = tl.load(
Q + off_q,
mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
other=0.0,
)
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
end_n = (
cur_batch_seq_len
if not IS_CAUSAL
else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
)
for start_n in range(0, block_mask * end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
other=0.0,
)
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if IS_CAUSAL:
qk += tl.where(
(start_n + offs_n[None, :] < cur_batch_seq_len)
& (offs_m[:, None] >= (start_n + offs_n[None, :])),
0,
float("-inf"),
)
else:
qk += tl.where(
(start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
)
out_ptrs = Out + off_o
tl.store(
out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
)
def context_attention_fwd(
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
"""
q, k, v: [b * s, head, head_dim]
b_start_loc: [b]
b_seq_len: [b]
out: [b * s, head, head_dim]
"""
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
BLOCK = 128
else:
BLOCK = 64
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
kv_group_num=kv_group_num,
BLOCK_M=BLOCK,
BLOCK_DMODEL=triton.next_power_of_2(Lk),
BLOCK_N=BLOCK,
IS_CAUSAL=is_causal,
num_warps=num_warps,
num_stages=1,
Lk=Lk,
)
\ No newline at end of file
from __future__ import annotations
from enum import IntEnum, auto
from typing import Optional, Union, List
import torch
class GPUVendor(IntEnum):
NVIDIA = auto()
AMD = auto()
MooreThreads = auto()
MetaX = auto()
MUSA = auto()
Unknown = auto()
class DeviceManager:
"""
Device manager that provides a unified interface for handling different GPU vendors
"""
def __init__(self):
self.gpu_vendor = self._detect_gpu_vendor()
self.available_devices = self._get_available_devices()
def _detect_gpu_vendor(self) -> GPUVendor:
"""Detect GPU vendor type"""
if not torch.cuda.is_available():
# Check MUSA availability (assuming a musa module exists)
try:
import musa
if musa.is_available():
return GPUVendor.MUSA
except (ImportError, AttributeError):
pass
return GPUVendor.Unknown
device_name = torch.cuda.get_device_name(0).lower()
if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
return GPUVendor.NVIDIA
elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
return GPUVendor.AMD
elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
return GPUVendor.MooreThreads
elif any(name in device_name for name in ["metax", "meta"]):
return GPUVendor.MetaX
elif "musa" in device_name:
return GPUVendor.MUSA
# Backend check
try:
if hasattr(torch.version, 'hip') and torch.version.hip is not None:
return GPUVendor.AMD
elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
return GPUVendor.NVIDIA
except:
pass
return GPUVendor.Unknown
def _get_available_devices(self) -> List[int]:
"""Get list of available device indices"""
devices = []
if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
devices = list(range(torch.cuda.device_count()))
elif self.gpu_vendor == GPUVendor.MUSA:
try:
import musa
devices = list(range(musa.device_count()))
except (ImportError, AttributeError):
pass
return devices
def get_device_str(self, device_id: Union[int, str]) -> str:
"""
Get device string for the given device ID
Args:
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
Device string representation (e.g., "cuda:0", "musa:1", "cpu")
"""
if device_id == -1 or device_id == "cpu":
return "cpu"
if isinstance(device_id, int):
if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
if device_id < torch.cuda.device_count():
return f"cuda:{device_id}"
elif self.gpu_vendor == GPUVendor.MUSA:
try:
import musa
if device_id < musa.device_count():
return f"musa:{device_id}"
except (ImportError, AttributeError):
pass
return "cpu"
def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
"""
Convert device ID to torch.device object
Args:
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
torch.device object
"""
device_str = self.get_device_str(device_id)
# Handle MUSA device
if device_str.startswith("musa:"):
try:
import musa
index = int(device_str.split(":")[-1])
return musa.device(index)
except (ImportError, ValueError, AttributeError):
return torch.device("cpu")
# Standard PyTorch device
return torch.device(device_str)
def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
"""
Move tensor to specified device
Args:
tensor: PyTorch tensor to move
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
Tensor moved to the specified device
"""
device = self.to_torch_device(device_id)
return tensor.to(device)
def is_available(self, index: int = 0) -> bool:
"""
Check if device at specified index is available
Args:
index: Device index to check
Returns:
True if the device is available, False otherwise
"""
if index < 0:
return True # CPU is always available
return index in self.available_devices
def get_all_devices(self) -> List[int]:
"""
Get all available device indices
Returns:
List of available device indices (0, 1, 2, etc.)
"""
return self.available_devices
# Create global device manager instance
device_manager = DeviceManager()
# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
"""
Get torch.device object for the specified device ID
Args:
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
torch.device object
"""
return device_manager.to_torch_device(device_id)
def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
"""
Move tensor to specified device
Args:
tensor: PyTorch tensor to move
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
Tensor moved to the specified device
"""
return device_manager.move_tensor_to_device(tensor, device_id)
# Get devices
cpu_device = get_device(-1) # CPU using index -1
cpu_device2 = get_device("cpu") # CPU using string "cpu"
gpu0 = get_device(0) # First GPU
# Move tensors
x = torch.randn(3, 3)
x_gpu = to_device(x, 0) # Move to first GPU
x_cpu1 = to_device(x, -1) # Move to CPU using index -1
x_cpu2 = to_device(x, "cpu") # Move to CPU using string "cpu"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment