Unverified Commit 9ef98d52 authored by Gerald's avatar Gerald Committed by GitHub
Browse files

[Model][MiniMaxText01] Support MiniMaxText01 model inference (#13454)


Signed-off-by: default avatarqscqesze <475517977@qq.com>
Co-authored-by: default avatarqingjun <qingjun@minimaxi.com>
Co-authored-by: default avatarqscqesze <475517977@qq.com>
parent 93491aef
......@@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎
* ✅︎
- * `MiniMaxText01ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-Text-01`, etc.
*
* ✅︎
- * `Zamba2ForCausalLM`
* Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm.model_executor.layers.lightning_attn import (
linear_decode_forward_triton)
from vllm.platforms import current_platform
NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
"""Reference implementation of lightning attention core algorithm
The difference from the main implementation is that this processes
each step sequentially, instead of using parallelized triton kernels
"""
B, H, S, D = q.shape
E = v.shape[-1]
dtype = q.dtype
output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)
# Use clone() to ensure an independent copy
if kv_history is None:
kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
else:
kv_cache = kv_history.clone()
# More efficient implementation
# Convert decay factors to matrix form
if ed.dim() == 1:
decay = torch.exp(-ed).view(1, -1, 1, 1)
else:
decay = torch.exp(-ed)
for b in range(B):
for step in range(S):
# Process all heads at once for this position
q_bs = q[b, :, step] # [H, D]
k_bs = k[b, :, step] # [H, D]
v_bs = v[b, :, step] # [H, E]
# Calculate KV outer products for all heads
for h in range(H):
# Calculate KV outer product
kv_outer = torch.outer(k_bs[h], v_bs[h])
# Update KV cache with decay
# Note: Using the same order as in the Triton kernel
kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer
# Calculate attention output
output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])
# Match the shape returned by the actual implementation
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
# where dimension 2 contains both KV and KV history
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
dim=2) # [B, H, 2, D, E]
return output, final_kv_cache
def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
"""Reference implementation: linear attention decode function"""
B, H, _, D = q.shape
output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)
# Calculate decay factors once (more efficient)
decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1]
# Process each batch
for b in range(B):
slot_id = slot_idx[b].item()
# Skip padding positions
if slot_id == -1:
continue
# Process all heads at once for this batch
q_b = q[b, :, 0] # [H, D]
k_b = k[b, :, 0] # [H, D]
v_b = v[b, :, 0] # [H, D]
# Process each attention head
for h in range(H):
# Get current query, key and value
q_bh = q_b[h]
k_bh = k_b[h]
v_bh = v_b[h]
# Get cache
kv_cache_old = kv_caches[b, h]
# Calculate new key-value outer product
kv_outer = torch.outer(k_bh, v_bh)
# Apply decay and update cache
kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old
# Calculate output
out_h = torch.matmul(q_bh, kv_new)
# Update output and cache
output[b, h * D:(h + 1) * D] = out_h
kv_caches[b, h] = kv_new
return output
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
batch_size: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.arange(batch_size, device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
torch.testing.assert_close(triton_output,
reference_output,
rtol=1e-1,
atol=1e-1)
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
assert triton_output.shape == (batch_size, num_heads * head_size)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
batch_size = 4
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
padding_mask = (slot_idx
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
triton_masked = triton_output[padding_mask]
reference_masked = reference_output[padding_mask]
atol, rtol = 1.5e-1, 1.5e-1
valid_indices = slot_idx != -1
for i in range(batch_size):
if valid_indices[i] > 0:
torch.testing.assert_close(kv_caches[i],
kv_caches_copy[i],
rtol=rtol,
atol=atol)
torch.testing.assert_close(triton_masked,
reference_masked,
rtol=rtol,
atol=atol)
assert triton_output.shape == (batch_size, num_heads * head_size)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
batch_size: int,
num_heads: int,
head_size: int,
seq_len: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
k = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
ed = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)
kv_history = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_history_clone = kv_history.clone()
ref_output, ref_kv_cache = reference_lightning_attention(
q, k, v, ed, 256, kv_history)
from vllm.model_executor.layers.lightning_attn import lightning_attention
actual_output, actual_kv_cache = lightning_attention(
q, k, v, ed, 256, kv_history_clone)
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
torch.testing.assert_close(ref_kv_cache,
actual_kv_cache,
rtol=rtol,
atol=atol)
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
assert ref_kv_cache.shape == actual_kv_cache.shape
......@@ -176,6 +176,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True),
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
......
......@@ -971,15 +971,10 @@ class ModelConfig:
return sum(not bc.attention.no_op
for bc in block_configs[start:end])
else:
# Hybrid model
# Hybrid model Jamba
layers_block_type_value = getattr(self.hf_config,
"layers_block_type", None)
if layers_block_type_value is None:
raise ValueError("The model is an hybrid without a "
"layers_block_type in the hf_config, "
"cannot determine the num of "
f"{block_type.value} layers")
if layers_block_type_value is not None:
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
......@@ -988,10 +983,23 @@ class ModelConfig:
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
return sum(t == block_type.value
for t in layers_block_type_value[start:end])
# Hybrid model Minimax
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])
if layers_block_type_value is None and attn_type_list is None:
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list in the hf_config,"
"cannot determine the num of "
f"{block_type.value} layers")
return sum(t == 1 for t in attn_type_list[start:end])
def get_multimodal_config(self) -> "MultiModalConfig":
"""
Get the multimodal configuration of the model.
......
......@@ -303,6 +303,9 @@ class _AsyncLLMEngine(LLMEngine):
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
if not scheduler_outputs.is_empty():
# this will cause mamba_cache/minimax_cache failed
# to release finished_requests_ids of the last steps
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> Tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_cache(from_index=index_exists,
to_index=destination_index)
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
return destination_index
else:
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.cache_indices_mapping[req_id][seq_id])
self.cache_indices_mapping.pop(req_id)
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Tuple
from typing import Tuple
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
......@@ -21,7 +22,7 @@ class MambaCacheParams:
self.state_indices_tensor)
class MambaCacheManager:
class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
num_mamba_layers: int, conv_state_shape: Tuple[int, int],
......@@ -32,6 +33,9 @@ class MambaCacheManager:
if not vllm_config.model_config.enforce_eager:
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
# Initialize parent class
super().__init__(max_batch_size)
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
......@@ -41,126 +45,32 @@ class MambaCacheManager:
dtype=dtype,
device="cuda")
self.mamba_cache = (conv_state, temporal_state)
self._mamba_cache = (conv_state, temporal_state)
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
def cache(self):
return self._mamba_cache
def _copy_cache(self, from_index: int, to_index: int):
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def current_run_tensors(self, **kwargs) -> MambaCacheParams:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
mamba_cache_tensors = self.mamba_cache
else:
# CUDA graph capturing runs
(mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
cache_tensors, state_indices_tensor = super().current_run_tensors(
**kwargs)
return MambaCacheParams(cache_tensors[0], cache_tensors[1],
state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.mamba_cache, state_indices_tensor)
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.mamba_cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
return destination_index
else:
# already exists
return self.mamba_cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
for seq_id in self.mamba_cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.mamba_cache_indices_mapping[req_id][seq_id])
self.mamba_cache_indices_mapping.pop(req_id)
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MinimaxCacheParams:
minimax_cache: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
self.state_indices_tensor)
class MinimaxCacheManager(ConstantSizeCache):
def __init__(self, dtype, cache_shape):
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
self._minimax_cache = torch.empty(size=cache_shape,
dtype=dtype,
device="cuda")
@property
def cache(self):
return self._minimax_cache
def _copy_cache(self, from_index: int, to_index: int):
assert len(self.cache) > 0
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
This diff is collapsed.
......@@ -35,6 +35,7 @@ _TEXT_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
# baichuan-7b, upper case 'C' in the class name
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
# baichuan-13b, lower case 'c' in the class name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment