Commit ecd710b6 authored by wangzhengtao's avatar wangzhengtao
Browse files

init push

parents
This diff is collapsed.
This diff is collapsed.
import torch
class KimiAContent:
    """Parallel audio/text token streams for Kimi audio-language modeling.

    Maintains three equal-length sequences (one entry per position):
    ``audio_token_ids``, ``text_token_ids``, and ``is_continuous_mask``,
    where the mask flags positions whose audio token is a placeholder for a
    continuous (embedding) feature rather than a discrete token. The
    continuous feature tensors themselves are accumulated separately in
    ``continuous_feature``.
    """

    def __init__(
        self, audio_token_ids=None, text_token_ids=None, is_continuous_mask=None
    ):
        # `or []` gives each instance a fresh list and avoids the
        # shared-mutable-default pitfall.
        self.audio_token_ids: list[int] = audio_token_ids or []
        self.text_token_ids: list[int] = text_token_ids or []
        # Fixed annotation: this list holds booleans (see audio_append /
        # to_tensor, which casts it to torch.bool), not ints.
        self.is_continuous_mask: list[bool] = is_continuous_mask or []
        # Continuous audio features; filled externally and combined via merge().
        self.continuous_feature = []

    def audio_append(self, index: int, is_continuous: bool = False):
        """Append one audio token id and its continuous-placeholder flag."""
        self.audio_token_ids.append(index)
        self.is_continuous_mask.append(is_continuous)

    def text_append(self, index: int):
        """Append one text token id (mask is untouched; keep streams aligned
        via audio_append)."""
        self.text_token_ids.append(index)

    def audio_extend(self, ids: list[int], is_continuous: bool = False):
        """Append several audio token ids, all sharing the same flag."""
        self.audio_token_ids.extend(ids)
        self.is_continuous_mask.extend([is_continuous] * len(ids))

    def text_extend(self, ids: list[int]):
        """Append several text token ids."""
        self.text_token_ids.extend(ids)

    def audio_prepend(self, index: int, is_continuous: bool = False):
        """Insert one audio token id (and its flag) at the front."""
        self.audio_token_ids = [index] + self.audio_token_ids
        self.is_continuous_mask = [is_continuous] + self.is_continuous_mask

    def text_prepend(self, index: int):
        """Insert one text token id at the front."""
        self.text_token_ids = [index] + self.text_token_ids

    # NOTE(review): "pretend" below is presumably a typo for "prepend", but the
    # names are part of the public interface and are kept for compatibility.
    def audio_pretend(self, ids: list[int], is_continuous: bool = False):
        """Insert several audio token ids at the front, sharing one flag."""
        self.audio_token_ids = ids + self.audio_token_ids
        self.is_continuous_mask = [is_continuous] * len(ids) + self.is_continuous_mask

    def text_pretend(self, ids: list[int]):
        """Insert several text token ids at the front."""
        self.text_token_ids = ids + self.text_token_ids

    def merge(self, other: "KimiAContent"):
        """Concatenate all four streams of *other* onto this instance."""
        self.audio_token_ids.extend(other.audio_token_ids)
        self.text_token_ids.extend(other.text_token_ids)
        self.is_continuous_mask.extend(other.is_continuous_mask)
        self.continuous_feature.extend(other.continuous_feature)

    def to_tensor(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (audio_ids, text_ids, mask) as batch-of-1 tensors.

        Audio/text ids are ``torch.long``; the mask is ``torch.bool``.
        """
        return (
            torch.tensor([self.audio_token_ids], dtype=torch.long),
            torch.tensor([self.text_token_ids], dtype=torch.long),
            torch.tensor([self.is_continuous_mask], dtype=torch.bool),
        )

    def is_valid(self):
        """True iff the three per-position streams have equal length."""
        return (
            len(self.audio_token_ids)
            == len(self.text_token_ids)
            == len(self.is_continuous_mask)
        )
This diff is collapsed.
from dataclasses import dataclass
@dataclass
class ExtraTokens:
    """Resolved vocabulary ids of the special/control tokens used by the
    Kimi audio chat format.

    Each field holds the integer id that the tokenizer assigns to the
    corresponding ``<|im_...|>`` token string (see instantiate_extra_tokens
    for the string-to-id resolution).
    """

    # Message framing
    msg_end: int
    user_msg_start: int
    assistant_msg_start: int
    # Media (e.g. audio) span delimiters
    media_begin: int
    media_end: int
    # Kimi text-stream control tokens
    kimia_text_blank: int
    kimia_text_eos: int
    # Kimi-specific message starts
    kimia_user_msg_start: int
    kimia_assistant_msg_start: int
    # Speech continuation control tokens
    kimia_speech_ct_id: int
    kimia_speech_ctd_id: int
    # Padding token id (taken from tokenizer.pad_id, not a named token)
    pad: int
def instantiate_extra_tokens(tokenizer):
    """Resolve the special-token strings to integer ids for *tokenizer*.

    Supports two tokenizer interfaces: one exposing a ``special_tokens``
    mapping (token string -> id), or a HuggingFace-style tokenizer exposing
    ``convert_tokens_to_ids``.

    Args:
        tokenizer: a tokenizer with either a ``special_tokens`` mapping or a
            ``convert_tokens_to_ids`` method, plus a ``pad_id`` attribute.

    Returns:
        ExtraTokens: dataclass of resolved special-token ids.

    Raises:
        ValueError: if the tokenizer exposes neither lookup interface.
    """
    if hasattr(tokenizer, "special_tokens"):
        # Mapping-style lookup; bind __getitem__ rather than wrapping in a
        # lambda (PEP 8 E731).
        map_fn = tokenizer.special_tokens.__getitem__
    elif hasattr(tokenizer, "convert_tokens_to_ids"):
        # HuggingFace-style tokenizer: the bound method already has the
        # right signature, no wrapper needed.
        map_fn = tokenizer.convert_tokens_to_ids
    else:
        raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}")
    # Trailing numbers are the expected ids in the reference vocabulary.
    return ExtraTokens(
        msg_end=map_fn("<|im_msg_end|>"),  # 0
        user_msg_start=map_fn("<|im_user_msg_start|>"),  # 1
        assistant_msg_start=map_fn("<|im_assistant_msg_start|>"),  # 2
        media_begin=map_fn("<|im_media_begin|>"),  # 13
        media_end=map_fn("<|im_media_end|>"),  # 15
        kimia_text_blank=map_fn("<|im_kimia_text_blank|>"),  # 18
        kimia_text_eos=map_fn("<|im_kimia_text_eos|>"),  # 19
        kimia_user_msg_start=map_fn("<|im_kimia_user_msg_start|>"),  # 22
        kimia_assistant_msg_start=map_fn("<|im_kimia_assistant_msg_start|>"),  # 23
        kimia_speech_ct_id=map_fn("<|im_kimia_speech_ct_id|>"),  # 27
        kimia_speech_ctd_id=map_fn("<|im_kimia_speech_ctd_id|>"),  # 28
        pad=tokenizer.pad_id,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment