Unverified commit 77a34c28, authored by UnicornChan and committed by GitHub

Merge pull request #36 from kvcache-ai/develop-0.1.2

Release v0.1.2
parents 44f57270 395cd3e7
@@ -3,7 +3,7 @@
 * @Author       : chenht2022
 * @Date         : 2024-07-22 02:03:22
 * @Version      : 1.0.0
 * @LastEditors  : chenht2022
 * @LastEditTime : 2024-07-25 10:35:10
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
@@ -22,6 +22,7 @@
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
+#include "shared_mem_buffer.h"
struct MOEConfig {
    int expert_num;
@@ -48,13 +49,13 @@ struct MOEConfig {
class MOE {
  public:
    MOE(MOEConfig);
+   ~MOE();
    void warm_up(Backend* backend);
    void forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);
    void forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);
    void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);

  private:
-   static uint8_t* buffer_;
    MOEConfig config_;
    void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
    void* up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
......
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-08-05 04:49:08
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-05 09:21:29
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "shared_mem_buffer.h"
#include <cstdio>
SharedMemBuffer::SharedMemBuffer() {
    buffer_ = nullptr;
    size_ = 0;
}

SharedMemBuffer::~SharedMemBuffer() {
    if (buffer_) {
        free(buffer_);
    }
}

void SharedMemBuffer::alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests) {
    uint64_t size = 0;
    for (auto& request : requests) {
        size += request.second;
    }
    if (size > size_) {
        if (buffer_) {
            free(buffer_);
        }
        buffer_ = malloc(size);
        size_ = size;
        for (auto& obj_requests : hist_requests_) {
            for (auto& requests : obj_requests.second) {
                arrange(requests);
            }
        }
    }
    arrange(requests);
    hist_requests_[object].push_back(requests);
}

void SharedMemBuffer::dealloc(void* object) {
    hist_requests_.erase(object);
}

void SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests) {
    uint64_t offset = 0;
    for (auto& request : requests) {
        *(request.first) = (uint8_t*)buffer_ + offset;
        offset += request.second;
    }
}
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-08-05 04:49:08
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-05 06:36:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_SHAREDMEMBUFFER_H
#define CPUINFER_SHAREDMEMBUFFER_H
#include <cstdint>
#include <cstdlib>
#include <map>
#include <vector>
class SharedMemBuffer {
   public:
    SharedMemBuffer();
    ~SharedMemBuffer();

    void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
    void dealloc(void* object);

   private:
    void* buffer_;
    uint64_t size_;
    std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;

    void arrange(std::vector<std::pair<void**, uint64_t>> requests);
};

static SharedMemBuffer shared_mem_buffer;
#endif
\ No newline at end of file
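The new SharedMemBuffer replaces the per-operator static buffer in MOE: each CPU operator registers its scratch-space requests, the singleton keeps one allocation sized to the largest total seen so far, and whenever it grows it re-derives the out-pointers of every previously registered object. Each registrant's layout starts at offset zero, so different objects deliberately alias the same memory; this appears intended because only one operator uses the scratch space at a time. Below is an illustrative Python sketch of that bookkeeping (not part of the commit; the "write the view into a holder dict" idiom stands in for the C++ void** out-pointers):

# Illustrative re-implementation of the SharedMemBuffer bookkeeping.
class SharedScratchBuffer:
    def __init__(self):
        self.buffer = bytearray(0)
        self.hist_requests = {}                  # object -> list of request lists

    def alloc(self, obj, requests):
        # requests: list of (holder_dict, key, size); the resulting view is written
        # into holder_dict[key], mimicking the void** out-pointers of the C++ version.
        total = sum(size for _, _, size in requests)
        if total > len(self.buffer):
            self.buffer = bytearray(total)       # grow; old contents are discarded
            for req_lists in self.hist_requests.values():
                for reqs in req_lists:
                    self._arrange(reqs)          # re-point earlier registrants
        self._arrange(requests)
        self.hist_requests.setdefault(obj, []).append(requests)

    def dealloc(self, obj):
        self.hist_requests.pop(obj, None)        # the buffer itself is kept for reuse

    def _arrange(self, requests):
        offset = 0
        for holder, key, size in requests:       # every layout starts at offset 0
            holder[key] = memoryview(self.buffer)[offset:offset + size]
            offset += size

# Hypothetical usage: one MoE layer registers two scratch regions.
bufs = {}
scratch = SharedScratchBuffer()
scratch.alloc("moe_layer_0", [(bufs, "gate_out", 4096), (bufs, "up_out", 4096)])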
@@ -31,18 +31,21 @@ import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
+from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config

custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
+   "MixtralForCausalLM": MixtralForCausalLM,
}

ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
default_optimize_rules = {
    "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
+   "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
}
def local_chat(
@@ -50,7 +53,8 @@ def local_chat(
    optimize_rule_path: str = None,
    gguf_path: str = None,
    max_new_tokens: int = 1000,
-   cpu_infer: int = Config().cpu_infer
+   cpu_infer: int = Config().cpu_infer,
+   use_cuda_graph: bool = True,
):
    torch.set_grad_enabled(False)
@@ -64,6 +68,8 @@ def local_chat(
        print("using custom modeling_xxx.py.")
        if "Qwen2Moe" in config.architectures[0]:  # Qwen2Moe must use flash_attention_2 to avoid overflow.
            config._attn_implementation = "flash_attention_2"
+       if "Mixtral" in config.architectures[0]:
+           config._attn_implementation = "flash_attention_2"
        model = custom_models[config.architectures[0]](config)
    else:
        model = AutoModelForCausalLM.from_config(
@@ -100,7 +106,6 @@ def local_chat(
    while True:
        content = input("Chat: ")
-       # if content is num
        if content == "":
            content = "Please write a piece of quicksort code in C++."
@@ -109,7 +114,7 @@ def local_chat(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    torch.set_default_dtype(torch.bfloat16)  # TODO: Remove this, replace dtype using config
-   generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens)
+   generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph)

if __name__ == "__main__":
    fire.Fire(local_chat)
\ No newline at end of file
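Since local_chat is exposed through fire.Fire, the new use_cuda_graph switch can be passed from the command line or as a keyword argument. A hedged sketch of a direct call, using only the parameters visible in this hunk (the leading model-path argument is elided in the diff, so it is omitted here and a real call would still need it; the module path and GGUF path are placeholders/assumptions):

from ktransformers.local_chat import local_chat   # import path assumed from the file shown above

local_chat(
    gguf_path="/path/to/model-gguf-dir",   # placeholder
    max_new_tokens=512,
    cpu_infer=32,
    use_cuda_graph=False,                  # new in this release: toggles CUDA-graph decoding
)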
@@ -22,13 +22,14 @@ class StaticCache(transformers.StaticCache):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
-       device (`torch.device`):
+       device (`torch.device` or `dict`):
            The device on which the cache should be initialized. Should be the same as the layer.
+           If a `dict`, it should contain the `device` key with the device name as the value.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

-   def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+   def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device | dict, dtype=None) -> None:
        Cache.__init__(self)
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
@@ -57,11 +58,15 @@ class StaticCache(transformers.StaticCache):
        self.past_tokens = []
        self.num_hidden_layers = config.num_hidden_layers
-       for _ in range(self.num_hidden_layers):
+       for idx in range(self.num_hidden_layers):
            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            # breaks when updating the cache.
-           new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=device)
-           new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=device)
+           if isinstance(device, dict):
+               target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
+           else:
+               target_device = device
+           new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
+           new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
......
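StaticCache now accepts either a single device or a per-layer mapping keyed by "blk.{idx}.self_attn", reading each layer's "generate_device" entry when allocating its KV tensors. A minimal sketch of both forms and of the resolution logic the constructor applies (device names and layer count are made up for illustration):

num_hidden_layers = 4

# Per-layer mapping: KV cache for layers 0-1 lives on cuda:0, layers 2-3 on cuda:1.
device_map = {
    f"blk.{idx}.self_attn": {"generate_device": "cuda:0" if idx < 2 else "cuda:1"}
    for idx in range(num_hidden_layers)
}

def resolve_cache_device(device, idx):
    # Mirrors the branch added in __init__ above.
    if isinstance(device, dict):
        return device[f"blk.{idx}.self_attn"]["generate_device"]
    return device

print([resolve_cache_device(device_map, i) for i in range(num_hidden_layers)])
print(resolve_cache_device("cuda:0", 0))   # the plain single-device form still works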
@@ -1048,7 +1048,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.
-       Args:
+       # Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
@@ -1245,12 +1245,14 @@ class DeepseekV2DecoderLayer(nn.Module):
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
......
@@ -10,6 +10,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
    def __init__(self,
@@ -17,12 +18,16 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
-                device: str = "cuda",
+                # device: str = "cuda",
+                generate_device: str = "cuda",
+                prefill_device: str = "cuda",
                 **kwargs):
-       BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+       BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.orig_module.__init__(orig_module.dim,
                                  orig_module.max_position_embeddings,
                                  orig_module.base)
+       self.generate_device = generate_device
+       self.prefill_device = prefill_device

    def load(self):
        self.orig_module.__init__(self.orig_module.dim,
@@ -36,9 +41,11 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
-                device: str = "cuda",
+                # device: str = "cuda",
+                generate_device: str = "cuda",
+                prefill_device: str = "cuda",
                 **kwargs):
-       BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+       BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.orig_module.__init__(orig_module.dim,
                                  orig_module.max_position_embeddings,
                                  orig_module.base,
@@ -49,13 +56,15 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
                                  orig_module.beta_slow,
                                  orig_module.mscale,
                                  orig_module.mscale_all_dim)
+       self.generate_device = generate_device
+       self.prefill_device = prefill_device

    def load(self):
        self.orig_module.__init__(self.orig_module.dim,
                                  self.orig_module.max_position_embeddings,
                                  self.orig_module.base,
-                                 self.device,
+                                 self.generate_device,
                                  self.orig_module.scaling_factor,
                                  self.orig_module.original_max_position_embeddings,
                                  self.orig_module.beta_fast,
......
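The rotary-embedding operators now take separate generate_device and prefill_device keyword arguments instead of a single device, so an injection rule can keep prompt prefill and token-by-token decoding on different GPUs. Purely illustrative kwargs follow; the rule-file plumbing that delivers them is not shown in this diff:

# Hypothetical kwargs an optimize rule could pass to RotaryEmbedding / YarnRotaryEmbedding.
rotary_kwargs = {
    "generate_device": "cuda:0",   # device used while decoding
    "prefill_device": "cuda:1",    # device used while prefilling the prompt
}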
@@ -15,7 +15,7 @@ from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache

-class DeepseekV2AttentionInjected(BaseInjectedModule, DeepseekV2Attention):
+class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
......
import sys, os
from typing import Any
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config
class CPUInfer:
    cpu_infer = None

    def __init__(self, cpu_infer: int = Config().cpu_infer):
        if CPUInfer.cpu_infer is None:
            CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)

    def __getattribute__(self, __name: str) -> Any:
        return CPUInfer.cpu_infer.__getattribute__(__name)

    def __setattr__(self, __name: str, __value: Any) -> None:
        return CPUInfer.cpu_infer.__setattr__(__name, __value)
\ No newline at end of file
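CPUInfer is a thin process-wide singleton: the first construction creates one cpuinfer_ext.CPUInfer with the requested thread count, and every later instance forwards all attribute access to that shared object through __getattribute__/__setattr__. A hedged usage sketch (the import path is assumed; the extension's own method names are not shown in this diff):

from ktransformers.operators.cpuinfer import CPUInfer   # module path assumed

worker = CPUInfer(cpu_infer=32)   # first call builds the shared backend with 32 CPU threads
other = CPUInfer()                # later calls reuse it; the argument is then ignored
assert CPUInfer.cpu_infer is not None   # both objects delegate to this single instance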
@@ -6,7 +6,7 @@ Author : Azure-Tang
Date : 2024-07-25 11:25:24
Version : 1.0.0
LastEditors : Azure
-LastEditTime : 2024-07-26 09:27:48
+LastEditTime : 2024-08-14 14:53:05
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
@@ -45,6 +45,8 @@ from ktransformers.models.modeling_deepseek import BaseModelOutputWithPast, Deep
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
+from ktransformers.util.custom_gguf import GGUFLoader
+from transformers.configuration_utils import PretrainedConfig

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -73,34 +75,6 @@ QWEN2MOE_START_DOCSTRING = r"""
        [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

-@add_start_docstrings(
-    "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
-    QWEN2MOE_START_DOCSTRING,
-)
-class Qwen2MoePreTrainedModel(PreTrainedModel):
-    config_class = Qwen2MoeConfig
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["Qwen2MoeDecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-    _supports_cache_class = True
-    _supports_static_cache = True
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()

QWEN2MOE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -177,13 +151,11 @@ QWEN2MOE_INPUTS_DOCSTRING = r"""
        the complete sequence length.
"""

-from ktransformers.util.custom_gguf import GGUFLoader
-from transformers.configuration_utils import PretrainedConfig
@add_start_docstrings(
    "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
    QWEN2MOE_START_DOCSTRING,
)
-class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule):
+class KQwen2MoeModel(BaseInjectedModule):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`]
@@ -198,10 +170,13 @@ class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule):
                 orig_module: nn.Module,
                 device: str = "cuda",
                 per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
+                transfer_map: dict = None,
                 **kwargs,
                 ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
+       self.transfer_map = transfer_map
+       self.stream_device_map = dict()

    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
    def forward(
@@ -287,7 +262,20 @@ class Qwen2MoeModelPerLayerPrefill(BaseInjectedModule):
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

-       for decoder_layer in self.layers:
+       for i, decoder_layer in enumerate(self.layers):
+           if self.transfer_map is not None and i in self.transfer_map:
+               prev_stream = torch.cuda.current_stream()
+               cur_device = self.transfer_map[i]
+               if cur_device not in self.stream_device_map:
+                   self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
+               torch.cuda.set_device(cur_device)
+               self.stream_device_map[cur_device].wait_stream(prev_stream)
+               torch.cuda.set_stream(self.stream_device_map[cur_device])
+               hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
+               causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
+               position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
+               cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
@@ -463,7 +451,7 @@ DeepseekV2_INPUTS_DOCSTRING = r"""
"""

-class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule):
+class KDeepseekV2Model(BaseInjectedModule):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]
@@ -478,10 +466,13 @@ class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule):
                 orig_module: nn.Module,
                 device: str = "cuda",
                 per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
+                transfer_map: dict = None,
                 **kwargs,
                 ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
+       self.transfer_map = transfer_map
+       self.stream_device_map = dict()

    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
    def forward(
@@ -584,7 +575,20 @@ class DeepseekV2ModelPerLayerPrefill(BaseInjectedModule):
        t_cpu = 0
        t_f = 0

-       for decoder_layer in self.layers:
+       for i, decoder_layer in enumerate(self.layers):
+           if self.transfer_map is not None and i in self.transfer_map:
+               prev_stream = torch.cuda.current_stream()
+               cur_device = self.transfer_map[i]
+               if cur_device not in self.stream_device_map:
+                   self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
+               torch.cuda.set_device(cur_device)
+               self.stream_device_map[cur_device].wait_stream(prev_stream)
+               torch.cuda.set_stream(self.stream_device_map[cur_device])
+               hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
+               causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
+               position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
+               cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
......
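Both injected model classes now accept a transfer_map that says "from this layer index on, run on this device". When the decode loop crosses a boundary it switches to a per-device CUDA stream, makes that stream wait on the previous one, and moves the rolling tensors over asynchronously. A condensed sketch of that hand-off, with made-up layer indices and device names:

import torch

transfer_map = {0: "cuda:0", 30: "cuda:1"}   # layer index -> device (illustrative)
stream_device_map = {}

def maybe_transfer(i, hidden_states):
    # Mirrors the per-layer device switch added to KQwen2MoeModel / KDeepseekV2Model.
    if transfer_map is not None and i in transfer_map:
        prev_stream = torch.cuda.current_stream()
        cur_device = transfer_map[i]
        if cur_device not in stream_device_map:
            stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
        torch.cuda.set_device(cur_device)
        stream_device_map[cur_device].wait_stream(prev_stream)   # order after prior layers
        torch.cuda.set_stream(stream_device_map[cur_device])
        hidden_states = hidden_states.to(cur_device, non_blocking=True)
    return hidden_states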