Commit 6a583c2f authored by chenych

update dtk to 24.04.1 and modify README

parent 7d576a9a
@@ -279,7 +279,8 @@ def apply_rotary_pos_emb(x, freqs, position_ids, use_yarn, yarn_scale_factor, at
     data_type = x.dtype
     rot_dim = freqs.shape[-1]
     freqs = freqs[position_ids]
-    freqs = freqs.view(x.shape[0],freqs.shape[1],freqs.shape[2],freqs.shape[4])
+    # feqs [b*s, 1, 1, head_dim]
+    freqs = freqs.view(x.shape[0],freqs.shape[1],freqs.shape[2])
     x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
     mscale = float(_yarn_get_mscale(yarn_scale_factor) * attn_factor) if use_yarn else 1.0
     x = (x * freqs.cos() * mscale) + (_rotate_half(x) * freqs.sin() * mscale)
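
For orientation, a minimal, self-contained sketch of the rotary application this hunk now performs with a 3-D freqs tensor (illustrative code only, not the repository's; the toy inv_freq and all shapes are assumptions):

```python
import torch

def _rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, freqs, mscale=1.0):
    # x: [tokens, heads, rot_dim]; freqs: [tokens, 1, rot_dim], broadcast over the head dimension.
    return (x * freqs.cos() * mscale) + (_rotate_half(x) * freqs.sin() * mscale)

tokens, heads, rot_dim = 5, 4, 8
positions = torch.arange(tokens)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.einsum('i,j->ij', positions.float(), inv_freq)   # [tokens, rot_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)[:, None, :]            # [tokens, 1, rot_dim], as in the new emb[:, None, :]
x = torch.randn(tokens, heads, rot_dim)
print(apply_rotary(x, emb).shape)                              # torch.Size([5, 4, 8])
```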
@@ -297,7 +298,7 @@ class YuanRotaryEmbedding(nn.Module):
         seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
         freqs = einsum('i , j -> i j', seq.float(), self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        return emb[:, None, None, :].float()
+        return emb[:, None, :].float()

 class LocalizedFiltering(torch.nn.Module):
     """
@@ -313,7 +314,7 @@ class LocalizedFiltering(torch.nn.Module):
         self.lf_conv2d_num_pad = 0
         self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
         self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
-        self.output_layernorm = RMSNorm(self.embed_dim, eps=1e-6)
+        self.output_layernorm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)

     def forward(self, inputs, lf1_cache, lf2_cache):
         inputs = inputs.permute([1, 0, 2]) # [ s, b, h]
@@ -422,6 +423,7 @@ class YuanAttention(nn.Module):
             linear_method=linear_method,
         )
+        self.model_type = getattr(config, 'model_type', 'yuan')
         self.lf_gate = LocalizedFiltering(self.config, self.hidden_size)
         self.attn = Attention(self.num_kv_heads,
                               self.attn_head_size,
@@ -434,7 +436,8 @@ class YuanAttention(nn.Module):
         hidden_states: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
         kv_cache: torch.Tensor,
-        lf_cache: LFCache,
+        lf1_cache: torch.Tensor,
+        lf2_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         use_yarn: bool=False,
         yarn_scale_factor: float=1.0,
@@ -443,33 +446,28 @@ class YuanAttention(nn.Module):
         q_len, hidden_size = hidden_states.size()
         bsz = attn_metadata.num_prefills + attn_metadata.num_decode_tokens
-        positions = positions.view(bsz, -1)
-        lf1_cache, lf2_cache = lf_cache
         v, _ = self.v_proj(hidden_states)
         v = v.view(*v.shape[:-1], self.num_heads, self.attn_head_size)
-        if attn_metadata.prefill_metadata:
-            lf1_cache_shape = (bsz, self.total_num_kv_heads * self.head_dim, 1, 1)
-            lf2_cache_shape = (bsz, self.total_num_kv_heads * self.head_dim // 2, 1, 1)
-            lf1 = torch.zeros(lf1_cache_shape, dtype=torch.bfloat16, device=hidden_states.device)
-            lf2 = torch.zeros(lf2_cache_shape, dtype=torch.bfloat16, device=hidden_states.device)
+        result = []
+        if attn_metadata.prefill_metadata != None:
+            for b in range(bsz):
+                tmp_hidden_states, lf1, lf2 = self.lf_gate(hidden_states[attn_metadata.prefill_metadata.seq_start_loc[b]:attn_metadata.prefill_metadata.seq_start_loc[b+1]].unsqueeze(0), lf1_cache[b:b+1], lf2_cache[b:b+1])
+                if lf1_cache != None and lf2_cache != None:
+                    lf1_cache[b:b+1].copy_(lf1)
+                    lf2_cache[b:b+1].copy_(lf2)
+                result.append(tmp_hidden_states.view(-1, *tmp_hidden_states.shape[2:]))
+            hidden_states = torch.cat(result, dim=0)
         else:
-            lf1 = lf1_cache[:bsz, :, :, :]
-            lf2 = lf2_cache[:bsz, :, :, :]
-        hidden_states = hidden_states.view(bsz, -1, hidden_size)
-        hidden_states, lf1, lf2 = self.lf_gate(hidden_states, lf1, lf2)
-        if lf1_cache is not None and lf2_cache is not None:
-            cache_ops.lf_reshape_and_cache(
-                lf1,
-                lf2,
-                lf1_cache,
-                lf2_cache
-            )
+            hidden_states = hidden_states.view(bsz, -1, hidden_size)
+            hidden_states, lf1, lf2 = self.lf_gate(hidden_states, lf1_cache, lf2_cache)
+            if lf1_cache != None and lf2_cache != None:
+                lf1_cache.copy_(lf1)
+                lf2_cache.copy_(lf2)
         hidden_states = hidden_states.contiguous().view(-1, hidden_states.shape[-1])
         qk, _ = self.qk_proj(hidden_states)
         qk = qk.view(*qk.shape[:-1], self.num_heads, int(qk.shape[-1] // self.num_heads))
         (q, k) = torch.chunk(qk, 2, dim=-1)
-        q = q.view(bsz, -1, *q.shape[1:])
-        k = k.view(bsz, -1, *k.shape[1:])
         q = apply_rotary_pos_emb(q , rotary_pos_emb, positions, use_yarn, yarn_scale_factor, attn_factor, attn_metadata)
         k = apply_rotary_pos_emb(k , rotary_pos_emb, positions, use_yarn, yarn_scale_factor, attn_factor, attn_metadata)
         v = v.view(*v.shape[:-2], self.num_heads * self.attn_head_size)
@@ -520,7 +518,8 @@ class YuanDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
         kv_cache: torch.Tensor,
-        lf_cache: LFCache,
+        lf1_cache: torch.Tensor,
+        lf2_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         use_yarn: bool=False,
         yarn_scale_factor: float=1.0,
@@ -534,7 +533,8 @@ class YuanDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             rotary_pos_emb=rotary_pos_emb,
             kv_cache=kv_cache,
-            lf_cache=lf_cache,
+            lf1_cache=lf1_cache,
+            lf2_cache=lf2_cache,
             attn_metadata=attn_metadata,
             use_yarn=use_yarn,
             yarn_scale_factor=yarn_scale_factor,
@@ -599,7 +599,8 @@ class YuanModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
-        lf_caches: List[LFCache],
+        lf1_caches: List[torch.Tensor],
+        lf2_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -610,7 +611,8 @@ class YuanModel(nn.Module):
                 hidden_states,
                 rotary_pos_emb,
                 kv_caches[i],
-                lf_caches[i],
+                lf1_caches[i],
+                lf2_caches[i],
                 attn_metadata,
                 self.use_yarn,
                 self.yarn_scale_factor,
@@ -654,15 +656,16 @@ class YuanForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
-        lf_caches: List[LFCache],
+        lf1_caches: List[torch.Tensor],
+        lf2_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         if attn_metadata.prefill_metadata == None:
             bsz = attn_metadata.num_decode_tokens
         else:
             bsz = attn_metadata.num_prefills
-        positions = positions.view(bsz, -1)
+        #bsz = 1
-        hidden_states = self.model(input_ids, positions, kv_caches, lf_caches, attn_metadata)
+        hidden_states = self.model(input_ids, positions, kv_caches, lf1_caches, lf2_caches, attn_metadata)
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
...
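
For orientation, a minimal, self-contained sketch of the per-sequence prefill path introduced in the hunk above: the packed hidden states are sliced per sequence via seq_start_loc, each slice runs through the LF gate with that sequence's caches, and the updated caches are written back in place. This is illustrative code only; the LF gate is replaced by a trivial stub and the cache shapes are simplified assumptions, not the model's real shapes.

```python
import torch

def lf_gate_stub(x, lf1, lf2):
    # Stand-in for LocalizedFiltering: passes the input through and returns
    # "updated" caches derived from the last token of the slice.
    return x, x[:, -1:, :], x[:, -1:, :x.shape[-1] // 2]

def prefill_per_sequence(hidden_states, seq_start_loc, lf1_cache, lf2_cache):
    """hidden_states: packed [total_tokens, hidden]; seq_start_loc: [bsz + 1] offsets."""
    bsz = seq_start_loc.numel() - 1
    result = []
    for b in range(bsz):
        chunk = hidden_states[seq_start_loc[b]:seq_start_loc[b + 1]].unsqueeze(0)
        out, lf1, lf2 = lf_gate_stub(chunk, lf1_cache[b:b + 1], lf2_cache[b:b + 1])
        lf1_cache[b:b + 1].copy_(lf1)   # write the per-sequence cache back in place
        lf2_cache[b:b + 1].copy_(lf2)
        result.append(out.view(-1, out.shape[-1]))
    return torch.cat(result, dim=0)     # re-pack to [total_tokens, hidden]

hidden = torch.randn(7, 8)              # two sequences: 3 and 4 tokens, hidden size 8
starts = torch.tensor([0, 3, 7])
lf1 = torch.zeros(2, 1, 8)
lf2 = torch.zeros(2, 1, 4)
print(prefill_per_sequence(hidden, starts, lf1, lf2).shape)   # torch.Size([7, 8])
```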
"""Sequence and its related classes.""" """Sequence and its related classes."""
import copy import copy
import enum import enum
import torch
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union from typing import TYPE_CHECKING, Dict, List, Optional, Union, Tuple
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
if TYPE_CHECKING: if TYPE_CHECKING:
import torch import torch
@@ -221,6 +223,7 @@ class Sequence:
         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.
+        # zhaoxd: when the blocks are initialized, only the prompt_token_ids are appended.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None
@@ -231,6 +234,9 @@ class Sequence:
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
+        self.lf1_caches = []
+        self.lf2_caches = []

     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
@@ -248,6 +254,7 @@ class Sequence:
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
+        # zhaoxd: the hash is generated from a tuple of the input token ids and the lora int id.
         return hash(
             (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
@@ -281,6 +288,26 @@ class Sequence:
                                           num_empty_slots])
             cursor += num_empty_slots

+    def get_lf_cache_shape(self, hidden_size) -> Tuple[int, int, int, int]:
+        return (1, hidden_size, 1, 1)
+
+    def create_lf_caches(
+        self,
+        hidden_size: int,
+        num_layers: int,
+        device: str = 'cuda',
+        dtype: torch.dtype = torch.float32,
+    ) -> None:
+        """Allocates LF cache on the specified device."""
+        lf1_cache_shape = self.get_lf_cache_shape(hidden_size)
+        lf2_cache_shape = self.get_lf_cache_shape(hidden_size // 2)
+        pin_memory = is_pin_memory_available() if device == "cpu" else False
+        for _ in range(num_layers):
+            lf1_cache = torch.zeros(lf1_cache_shape, dtype=dtype, pin_memory=pin_memory, device=device)
+            lf2_cache = torch.zeros(lf2_cache_shape, dtype=dtype, pin_memory=pin_memory, device=device)
+            self.lf1_caches.append(lf1_cache)
+            self.lf2_caches.append(lf2_cache)
+
     def append_token_id(
         self,
         token_id: int,
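
The two methods added above allocate one pair of LF-cache tensors per layer on the sequence itself. As a rough usage sketch (illustrative only: _SeqStub, the hidden size, and the layer count are assumptions, not vLLM's API), allocation and the resulting shapes look like this:

```python
import torch

# Hypothetical stand-in for a Sequence that already owns empty lf1_caches/lf2_caches lists.
class _SeqStub:
    def __init__(self):
        self.lf1_caches, self.lf2_caches = [], []

    def get_lf_cache_shape(self, hidden_size):
        return (1, hidden_size, 1, 1)

    def create_lf_caches(self, hidden_size, num_layers, device='cpu', dtype=torch.float32):
        # Same allocation pattern as the added method: one lf1/lf2 pair per layer.
        for _ in range(num_layers):
            self.lf1_caches.append(torch.zeros(self.get_lf_cache_shape(hidden_size), dtype=dtype, device=device))
            self.lf2_caches.append(torch.zeros(self.get_lf_cache_shape(hidden_size // 2), dtype=dtype, device=device))

seq = _SeqStub()
seq.create_lf_caches(hidden_size=2048, num_layers=24)    # example values, not the model's config
print(len(seq.lf1_caches), seq.lf1_caches[0].shape)      # 24 torch.Size([1, 2048, 1, 1])
print(seq.lf2_caches[0].shape)                           # torch.Size([1, 1024, 1, 1])
```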
@@ -576,6 +603,8 @@ class SequenceGroupMetadata:
         computed_block_nums: Optional[List[int]] = None,
         state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional[MultiModalData] = None,
+        lf1_caches: List[List[torch.Tensor]] = None,
+        lf2_caches: List[List[torch.Tensor]] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -587,6 +616,8 @@ class SequenceGroupMetadata:
         self.multi_modal_data = multi_modal_data
         self.state = SequenceGroupState() if state is None else state
         self._token_chunk_size = token_chunk_size
+        self.lf1_caches = lf1_caches
+        self.lf2_caches = lf2_caches
         if self._token_chunk_size is None:
             if is_prompt:
...