"official/r1/README.md" did not exist on "831281cedfc8a4a0ad7c0c37173963fafb99da37"
Commit 6a583c2f authored by chenych's avatar chenych
Browse files

update dtk to 24.04.1 and modify README

parent 7d576a9a
@@ -279,7 +279,8 @@ def apply_rotary_pos_emb(x, freqs, position_ids, use_yarn, yarn_scale_factor, attn_factor, attn_metadata):
data_type = x.dtype
rot_dim = freqs.shape[-1]
freqs = freqs[position_ids]
freqs = freqs.view(x.shape[0],freqs.shape[1],freqs.shape[2],freqs.shape[4])
# freqs [b*s, 1, 1, head_dim]
freqs = freqs.view(x.shape[0],freqs.shape[1],freqs.shape[2])
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
mscale = float(_yarn_get_mscale(yarn_scale_factor) * attn_factor) if use_yarn else 1.0
x = (x * freqs.cos() * mscale) + (_rotate_half(x) * freqs.sin() * mscale)
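`_rotate_half` and `_yarn_get_mscale` are referenced here but defined outside this hunk; a minimal sketch of the conventional definitions, assumed rather than taken from this commit:

```python
import math
import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Conventional rotary helper: split the last dim and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def _yarn_get_mscale(scale: float = 1.0) -> float:
    # YaRN attention-magnitude correction; identity when no context scaling is used.
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0
```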
@@ -297,7 +298,7 @@ class YuanRotaryEmbedding(nn.Module):
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
freqs = einsum('i , j -> i j', seq.float(), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb[:, None, None, :].float()
return emb[:, None, :].float()
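The change drops one singleton axis from the cached table so it matches the three-dimensional `freqs.view` in `apply_rotary_pos_emb` above. A toy shape check, with illustrative sizes that are not from this commit:

```python
import torch

max_seq_len, head_dim = 16, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum('i,j->ij', torch.arange(max_seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)                              # [s, head_dim]
assert emb[:, None, None, :].shape == (max_seq_len, 1, 1, head_dim)  # old layout
assert emb[:, None, :].shape == (max_seq_len, 1, head_dim)           # new layout
```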
class LocalizedFiltering(torch.nn.Module):
"""
@@ -313,7 +314,7 @@ class LocalizedFiltering(torch.nn.Module):
self.lf_conv2d_num_pad = 0
self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
self.output_layernorm = RMSNorm(self.embed_dim, eps=1e-6)
self.output_layernorm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
def forward(self, inputs, lf1_cache, lf2_cache):
inputs = inputs.permute([1, 0, 2]) # [ s, b, h]
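For orientation, a hedged sketch of how one LocalizedFiltering step can consume and refresh the two caches that the rest of this commit threads through the stack. The function name, residual placement, and exact tensor layout are assumptions, not code from this commit; only the cache shapes follow the diff:

```python
import torch

def lf_gate_step(inputs, lf1_cache, lf2_cache, conv1, conv2, norm):
    """Sketch of one localized-filtering step over cached left context.

    inputs:    [s, b, h]       (after the permute in forward above)
    lf1_cache: [b, h, 1, 1]    last pre-conv1 time step from the previous call
    lf2_cache: [b, h//2, 1, 1] last pre-conv2 time step from the previous call
    """
    s, b, h = inputs.shape
    residual = inputs
    x = inputs.permute(1, 2, 0).unsqueeze(-1)  # [b, h, s, 1]
    x = torch.cat([lf1_cache, x], dim=2)       # prepend one step of left context
    lf1 = x[:, :, -1:, :]                      # refreshed cache for the next call
    x = conv1(x)                               # (2, 1) kernel eats the extra step -> [b, h//2, s, 1]
    x = torch.cat([lf2_cache, x], dim=2)
    lf2 = x[:, :, -1:, :]
    x = conv2(x)                               # -> [b, h, s, 1]
    out = x.squeeze(-1).permute(2, 0, 1)       # back to [s, b, h]
    return norm(out + residual), lf1, lf2
```

With s == 1 this reduces to single-token decode, which is why the small per-sequence caches added later in this commit are enough to resume generation.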
@@ -422,6 +423,7 @@ class YuanAttention(nn.Module):
linear_method=linear_method,
)
self.model_type = getattr(config, 'model_type', 'yuan')
self.lf_gate = LocalizedFiltering(self.config, self.hidden_size)
self.attn = Attention(self.num_kv_heads,
self.attn_head_size,
@@ -434,7 +436,8 @@
hidden_states: torch.Tensor,
rotary_pos_emb: torch.Tensor,
kv_cache: torch.Tensor,
lf_cache: LFCache,
lf1_cache: torch.Tensor,
lf2_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
use_yarn: bool=False,
yarn_scale_factor: float=1.0,
@@ -443,33 +446,28 @@
q_len, hidden_size = hidden_states.size()
bsz = attn_metadata.num_prefills + attn_metadata.num_decode_tokens
positions = positions.view(bsz, -1)
lf1_cache, lf2_cache = lf_cache
v, _ = self.v_proj(hidden_states)
v = v.view(*v.shape[:-1], self.num_heads, self.attn_head_size)
if attn_metadata.prefill_metadata:
lf1_cache_shape = (bsz, self.total_num_kv_heads * self.head_dim, 1, 1)
lf2_cache_shape = (bsz, self.total_num_kv_heads * self.head_dim // 2, 1, 1)
lf1 = torch.zeros(lf1_cache_shape, dtype=torch.bfloat16, device=hidden_states.device)
lf2 = torch.zeros(lf2_cache_shape, dtype=torch.bfloat16, device=hidden_states.device)
result = []
if attn_metadata.prefill_metadata is not None:
for b in range(bsz):
tmp_hidden_states, lf1, lf2 = self.lf_gate(hidden_states[attn_metadata.prefill_metadata.seq_start_loc[b]:attn_metadata.prefill_metadata.seq_start_loc[b+1]].unsqueeze(0), lf1_cache[b:b+1], lf2_cache[b:b+1])
if lf1_cache is not None and lf2_cache is not None:
lf1_cache[b:b+1].copy_(lf1)
lf2_cache[b:b+1].copy_(lf2)
result.append(tmp_hidden_states.view(-1, *tmp_hidden_states.shape[2:]))
hidden_states = torch.cat(result, dim=0)
else:
lf1 = lf1_cache[:bsz, :, :, :]
lf2 = lf2_cache[:bsz, :, :, :]
hidden_states = hidden_states.view(bsz, -1, hidden_size)
hidden_states, lf1, lf2 = self.lf_gate(hidden_states, lf1, lf2)
if lf1_cache is not None and lf2_cache is not None:
cache_ops.lf_reshape_and_cache(
lf1,
lf2,
lf1_cache,
lf2_cache
)
hidden_states = hidden_states.view(bsz, -1, hidden_size)
hidden_states, lf1, lf2 = self.lf_gate(hidden_states, lf1_cache, lf2_cache)
if lf1_cache is not None and lf2_cache is not None:
lf1_cache.copy_(lf1)
lf2_cache.copy_(lf2)
hidden_states = hidden_states.contiguous().view(-1, hidden_states.shape[-1])
qk, _ = self.qk_proj(hidden_states)
qk = qk.view(*qk.shape[:-1], self.num_heads, int(qk.shape[-1] // self.num_heads))
(q, k) = torch.chunk(qk, 2, dim=-1)
q = q.view(bsz, -1, *q.shape[1:])
k = k.view(bsz, -1, *k.shape[1:])
q = apply_rotary_pos_emb(q, rotary_pos_emb, positions, use_yarn, yarn_scale_factor, attn_factor, attn_metadata)
k = apply_rotary_pos_emb(k, rotary_pos_emb, positions, use_yarn, yarn_scale_factor, attn_factor, attn_metadata)
v = v.view(*v.shape[:-2], self.num_heads * self.attn_head_size)
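The prefill branch above slices the packed batch per sequence via `seq_start_loc`, which holds cumulative token offsets. A toy illustration with made-up values:

```python
import torch

lens = [3, 5, 2]                             # per-sequence prompt lengths
seq_start_loc = torch.tensor([0, 3, 8, 10])  # cumulative offsets, bsz + 1 entries
packed = torch.randn(10, 8)                  # [total_tokens, hidden_size]
chunks = [packed[seq_start_loc[b]:seq_start_loc[b + 1]] for b in range(len(lens))]
assert [c.shape[0] for c in chunks] == lens  # one chunk per sequence
```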
@@ -520,7 +518,8 @@ class YuanDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
rotary_pos_emb: torch.Tensor,
kv_cache: torch.Tensor,
lf_cache: LFCache,
lf1_cache: torch.Tensor,
lf2_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
use_yarn: bool=False,
yarn_scale_factor: float=1.0,
@@ -534,7 +533,8 @@
hidden_states=hidden_states,
rotary_pos_emb=rotary_pos_emb,
kv_cache=kv_cache,
lf_cache=lf_cache,
lf1_cache=lf1_cache,
lf2_cache=lf2_cache,
attn_metadata=attn_metadata,
use_yarn=use_yarn,
yarn_scale_factor=yarn_scale_factor,
@@ -599,7 +599,8 @@ class YuanModel(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
lf_caches: List[LFCache],
lf1_caches: List[torch.Tensor],
lf2_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@@ -610,7 +611,8 @@
hidden_states,
rotary_pos_emb,
kv_caches[i],
lf_caches[i],
lf1_caches[i],
lf2_caches[i],
attn_metadata,
self.use_yarn,
self.yarn_scale_factor,
@@ -654,15 +656,16 @@ class YuanForCausalLM(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
lf_caches: List[LFCache],
lf1_caches: List[torch.Tensor],
lf2_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if attn_metadata.prefill_metadata is None:
bsz = attn_metadata.num_decode_tokens
else:
bsz = attn_metadata.num_prefills
positions = positions.view(bsz, -1)
hidden_states = self.model(input_ids, positions, kv_caches, lf_caches, attn_metadata)
#bsz = 1
hidden_states = self.model(input_ids, positions, kv_caches, lf1_caches, lf2_caches, attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
......
"""Sequence and its related classes."""
import copy
import enum
import torch
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Tuple
from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
if TYPE_CHECKING:
import torch
@@ -221,6 +223,7 @@ class Sequence:
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
# zhaoxd: only append prompt_token_ids when the blocks are initialized
self._append_tokens_to_blocks(prompt_token_ids)
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None
@@ -231,6 +234,9 @@
# Input + output tokens
self.tokens: Optional[List[str]] = None
self.lf1_caches = []
self.lf2_caches = []
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@@ -248,6 +254,7 @@
# TODO: The current hashing function is O(L^2). We should optimize
# this in the future.
num_tokens = self.num_hashed_tokens_of_block(logical_idx)
# zhaoxd: build the hash from a tuple of the input token ids and the lora int id
return hash(
(tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
@@ -281,6 +288,26 @@
num_empty_slots])
cursor += num_empty_slots
def get_lf_cache_shape(self, hidden_size) -> Tuple[int, int, int, int]:
return (1, hidden_size, 1, 1)
def create_lf_caches(
self,
hidden_size: int,
num_layers: int,
device: str = 'cuda',
dtype: torch.dtype = torch.float32,
) -> None:
"""Allocates LF cache on the specified device."""
lf1_cache_shape = self.get_lf_cache_shape(hidden_size)
lf2_cache_shape = self.get_lf_cache_shape(hidden_size // 2)
pin_memory = is_pin_memory_available() if device == "cpu" else False
for _ in range(num_layers):
lf1_cache = torch.zeros(lf1_cache_shape, dtype=dtype, pin_memory=pin_memory, device=device)
lf2_cache = torch.zeros(lf2_cache_shape, dtype=dtype, pin_memory=pin_memory, device=device)
self.lf1_caches.append(lf1_cache)
self.lf2_caches.append(lf2_cache)
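A hedged usage sketch; the hidden size, layer count, and dtype below are illustrative assumptions, and `seq` stands for an existing `Sequence`:

```python
# Allocate per-sequence LF caches once, e.g. right after the Sequence is built.
seq.create_lf_caches(hidden_size=2048, num_layers=24,
                     device='cuda', dtype=torch.bfloat16)
assert seq.lf1_caches[0].shape == (1, 2048, 1, 1)  # get_lf_cache_shape(hidden_size)
assert seq.lf2_caches[0].shape == (1, 1024, 1, 1)  # get_lf_cache_shape(hidden_size // 2)
```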
def append_token_id(
self,
token_id: int,
@@ -576,6 +603,8 @@ class SequenceGroupMetadata:
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional[MultiModalData] = None,
lf1_caches: Optional[List[List[torch.Tensor]]] = None,
lf2_caches: Optional[List[List[torch.Tensor]]] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
@@ -587,6 +616,8 @@
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self._token_chunk_size = token_chunk_size
self.lf1_caches = lf1_caches
self.lf2_caches = lf2_caches
if self._token_chunk_size is None:
if is_prompt:
......