Unverified Commit 7527619f authored by UnicornChan, committed by GitHub

Merge pull request #122 from kvcache-ai/feat-DeepSeekV3

[Feat] add support to DeepSeekV3
parents f4903d54 6f0fe953
...@@ -142,11 +142,11 @@ jobs: ...@@ -142,11 +142,11 @@ jobs:
- name: Setup Mamba - name: Setup Mamba
if: matrix.cuda != '' if: matrix.cuda != ''
uses: conda-incubator/setup-miniconda@v2.3.0 uses: conda-incubator/setup-miniconda@v3
with: with:
activate-environment: "ktransformers" activate-environment: "ktransformers"
python-version: ${{ matrix.pyver }} python-version: ${{ matrix.pyver }}
miniforge-variant: Mambaforge miniforge-variant: Miniforge3
miniforge-version: latest miniforge-version: latest
use-mamba: true use-mamba: true
add-pip-as-python-dependency: true add-pip-as-python-dependency: true
......
...@@ -54,11 +54,11 @@ jobs: ...@@ -54,11 +54,11 @@ jobs:
- name: Setup Mamba - name: Setup Mamba
if: matrix.cuda != '' if: matrix.cuda != ''
uses: conda-incubator/setup-miniconda@v2.3.0 uses: conda-incubator/setup-miniconda@v3
with: with:
activate-environment: "ktransformers" activate-environment: "ktransformers"
python-version: ${{ matrix.pyver }} python-version: ${{ matrix.pyver }}
miniforge-variant: Mambaforge miniforge-variant: Miniforge3
miniforge-version: latest miniforge-version: latest
use-mamba: true use-mamba: true
add-pip-as-python-dependency: true add-pip-as-python-dependency: true
......
...@@ -18,4 +18,7 @@ compile_commands.json ...@@ -18,4 +18,7 @@ compile_commands.json
ktransformers/server/local_store/ ktransformers/server/local_store/
ktransformers/server_test1.db ktransformers/server_test1.db
*.patch *.patch
img/ img/
\ No newline at end of file tmp1.txt
test_65_300_1536.txt
test.txt
...@@ -17,5 +17,5 @@ dev_install: ...@@ -17,5 +17,5 @@ dev_install:
pip install -r requirements-local_chat.txt pip install -r requirements-local_chat.txt
echo "Installing ktransformers" echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . --no-build-isolation KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation
echo "Installation completed successfully" echo "Installation completed successfully"
\ No newline at end of file
...@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin ...@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
<h2 id="Updates">🔥 Updates</h2> <h2 id="Updates">🔥 Updates</h2>
* **Feb 10, 2025**: Support DeepSeek-R1 and V3 on single (24GB VRAM)/multi-GPU and 382G DRAM, up to 3~64x speedup. The detailed tutorial is [here](./doc/en/DeepseekR1_V3_tutorial.md)
* **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md). * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
* **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU. * **Aug 15, 2024**: Update detailed [TUTORIAL](doc/en/injection_tutorial.md) for injection and multi-GPU.
...@@ -30,54 +31,68 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin ...@@ -30,54 +31,68 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
* **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu. * **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu.
* **Aug 9, 2024**: Support windows native. * **Aug 9, 2024**: Support windows native.
<h2 id="show-cases">🔥 Show Cases</h2> <h2 id="show-cases">🌟 Show Cases</h2>
<h3>1M Context Local Inference on a Desktop with Only 24GB VRAM</h3>
<p align="center">
https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 <div>
<h3>GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM</h3>
</div>
* **1M Context InternLM 2.5 7B**: Operates at full bf16 precision, utilizing 24GB VRAM and 150GB DRAM, which is feasible on a local desktop setup. It achieves a 92.88% success rate on the 1M "Needle In a Haystack" test and 100% on the 128K NIAH test. https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
<p align="center">
<picture>
<img alt="Single Needle Retrieval 128K" src="./doc/assets/needle_128K.png" width=100%>
</picture>
</p> </p>
- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM.
- Prefill Speed (tokens/s):
- KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only)
- Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **63.53× speedup**.
- Decode Speed (tokens/s):
- KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only)
- Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
- Upcoming Open Source Release:
- AMX optimizations and selective expert activation will be open-sourced in V0.3.
- Currently available only in preview binary distribution, which can be downloaded [here](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl).
- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).
<p align="center"> <p align="center">
<picture> <picture>
<img alt="Single Needle Retrieval 1000K" src="./doc/assets/needle_1M.png" width=100%> <img alt="DeepSeek-Coder-V2 Score" src="https://github.com/user-attachments/assets/d052924e-8631-44de-aad2-97c54b965693" width=100%>
</picture> </picture>
</p> </p>
* **Enhanced Speed**: Reaches 16.91 tokens/s for generation with a 1M context using sparse attention, powered by llamafile kernels. This method is over 10 times faster than full attention approach of llama.cpp. - **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).
- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.
* **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md).
<div> <p align="center">
<h3>GPT-4-level Local VSCode Copilot on a Desktop with only 24GB VRAM</h3>
</div>
https://github.com/user-attachments/assets/0b9fa2da-66f0-48eb-b4b9-f0e1f06f8927 https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
</p> </p>
- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench). <h3>1M Context Local Inference on a Desktop with Only 24GB VRAM</h3>
<p align="center">
https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12
* **1M Context InternLM 2.5 7B**: Operates at full bf16 precision, utilizing 24GB VRAM and 150GB DRAM, which is feasible on a local desktop setup. It achieves a 92.88% success rate on the 1M "Needle In a Haystack" test and 100% on the 128K NIAH test.
<p align="center"> <p align="center">
<picture> <picture>
<img alt="DeepSeek-Coder-V2 Score" src="https://github.com/user-attachments/assets/d052924e-8631-44de-aad2-97c54b965693" width=100%> <img alt="Single Needle Retrieval 128K" src="./doc/assets/needle_128K.png" width=100%>
</picture> </picture>
</p> </p>
- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).
- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.
<p align="center"> <p align="center">
<picture>
<img alt="Single Needle Retrieval 1000K" src="./doc/assets/needle_1M.png" width=100%>
</picture>
</p>
* **Enhanced Speed**: Reaches 16.91 tokens/s for generation with a 1M context using sparse attention, powered by llamafile kernels. This method is over 10 times faster than the full-attention approach of llama.cpp.
* **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md).
https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
</p>
<strong>More advanced features are coming soon, so stay tuned!</strong> <strong>More advanced features are coming soon, so stay tuned!</strong>
......
# GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM
# SUMMARY
> **Feb 10, 2025**: Support DeepSeek-R1 and V3 on single (24GB VRAM)/multi-GPU and 382G DRAM, up to 3~64x speedup.<br>
Hi, we're the KTransformers team (formerly known for our local CPU/GPU hybrid inference open source project with DeepSeek-V2).
We've heard your requests for DeepSeek-R1/V3 support—and we're excited to finally deliver!
Apologies for the wait, but we've been cooking up something truly amazing!
Today, we're proud to announce support for DeepSeek-R1/V3, as showcased in the video below:
https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM.
- Prefill Speed (tokens/s):
- KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only)
- Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **63.53× speedup**.
- Decode Speed (tokens/s):
- KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only)
- Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
We also preview our upcoming optimizations, including an Intel AMX-accelerated kernel and a selective expert activation method, which will significantly enhance performance. With V0.3-preview, we achieve up to 286 tokens/s for prefill, making it up to **64× faster than llama.cpp** for local inference.
The binary distribution is available now and the source code will come ASAP! Check out the wheel package [here](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl).
## Prerequisites
We run our best performance tests (V0.2) on: <br>
CPU: Intel(R) Xeon(R) Gold 6454S, 1TB DRAM (2 NUMA nodes) <br>
GPU: 4090D, 24GB VRAM <br>
## Bench Result
### V0.2
#### Settings
- Model: DeepseekV3-q4km (int4)<br>
- CPU: Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 sockets, 2 NUMA nodes
- GPU: 4090D, 24GB VRAM
- We test after sufficient warm-up
#### Memory consumption:
- Single socket: 382GB DRAM, at least 14GB VRAM
- Dual socket: 1TB DRAM, at least 14GB VRAM
#### Benchmark Results
"6 experts" case is part of V0.3's preview
| Prompt<br>(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts)| llama.cpp (8 experts) |
| --- | --- | --- | --- | --- | --- |
| Prefill token/s | 97.32 | 82.94 | 65.14 | 54.21 | 10.31 |
| Decode token/s | 13.69 | 12.208 | 10.303 | 8.73 | 4.51 |
**The highest speedup reaches up to <u>3.03x</u> in decoding and <u>9.44x</u> in prefill.**
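As a quick arithmetic check (a sketch, not part of the repository), the headline factors follow directly from the table above, comparing the dual-socket 6-expert KTransformers run against llama.cpp with 8 experts:

``` python
# Recompute the headline speedups from the benchmark table above
# (dual-socket KTrans, 6 experts vs. llama.cpp, 8 experts).
ktrans_prefill, ktrans_decode = 97.32, 13.69      # tokens/s
llamacpp_prefill, llamacpp_decode = 10.31, 4.51   # tokens/s

print(ktrans_prefill / llamacpp_prefill)  # ~9.4x prefill speedup
print(ktrans_decode / llamacpp_decode)    # ~3.0x decode speedup
```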
### V0.3-Preview
#### Settings
- Model: DeepseekV3-BF16 (online-quantized to int8 for CPU and int4 for GPU)
- CPU: Intel(R) Xeon(R) Gold 6454S, 32 cores per socket, 2 sockets, 2 NUMA nodes
- GPU: (1~4)x 4090D, 24GB VRAM (longer prompts require more VRAM)
#### Memory consumption:
- 644GB DRAM, at least 14GB VRAM
#### Benchmark results
| Prompt length | 1K | 2K | 4K | 8K |
|---------------|-----|-----|-----|-----|
| KTrans (8 experts) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 |
| KTrans (6 experts) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 |
**The prefill of KTrans V0.3 is up to <u>3.45x</u> faster than KTrans V0.2, and up to <u>63.53x</u> faster than llama.cpp.**
**The decoding speed is the same as KTrans V0.2 (6-expert version), so it is omitted.**
The main acceleration comes from:
- the Intel AMX instruction set and our specially designed, cache-friendly memory layout
- an expert selection strategy that activates fewer experts, based on offline profiling results on out-of-domain data
*From our research on DeepSeek-V2, DeepSeek-V3, and DeepSeek-R1, slightly decreasing the number of activated experts at inference time does not change the output quality, but it speeds up both decoding and prefill, which is encouraging. Our showcase makes use of this finding.*
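As a rough illustration of this idea (a sketch only, not the mechanism KTransformers uses internally), the number of activated experts is exposed as the `num_experts_per_tok` hyperparameter in the DeepseekV3Config added by this PR, so it can be inspected and lowered before a model is built:

``` python
# Sketch: inspect and lower the number of activated experts per token.
# Illustrative only; this is not how the "6 experts" benchmark rows were produced.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
print(cfg.num_experts_per_tok)  # 8 by default for DeepSeek-V3
cfg.num_experts_per_tok = 6     # activate fewer experts -> faster prefill/decode
```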
## How to Run
### V0.2 Showcase
#### Single socket version (32 cores)
Our local_chat test command is:
``` shell
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
numactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 33 --cache_lens 1536
# when you see "Chat:", press enter to load the text from your prompt_file
```
\<your model path\> can be a local path or an online Hugging Face id such as deepseek-ai/DeepSeek-V3. If the online download hits connection problems, try the mirror (hf-mirror.com). <br>
\<your gguf path\> can also be online, but since the file is large we recommend downloading it and quantizing the model to the format you want. <br>
The `numactl -N 1 -m 1` prefix is there to avoid data transfer between NUMA nodes.
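Because `local_chat` is exposed through `fire.Fire`, the same run can also be started from Python; the sketch below assumes the `local_chat()` signature introduced in this PR and uses placeholder paths (only a subset of the CLI flags is shown). Run the interpreter under `numactl -N 1 -m 1` as above to keep the single-node binding.

``` python
# Sketch: start the single-socket chat session from Python instead of the CLI.
# Paths are placeholders; keyword arguments follow the local_chat() signature in this PR.
from ktransformers.local_chat import local_chat

local_chat(
    model_path="deepseek-ai/DeepSeek-V3",    # local path or Hugging Face model id
    gguf_path="/path/to/DeepSeek-V3-GGUF",   # directory holding the GGUF files
    prompt_file="/path/to/prompt.txt",
    cpu_infer=33,                            # 32 physical cores + 1, as in the command above
    max_new_tokens=1000,
)
```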
#### Dual socket version (64 cores)
Make sure that before you install (using install.sh or `make dev_install`) you set the env var `USE_NUMA=1`, e.g. with `export USE_NUMA=1` (if already installed, reinstall with this env var set). <br>
Our local_chat test command is:
``` shell
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
export USE_NUMA=1
make dev_install # or sh ./install.sh
python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
# when you see "Chat:", press enter to load the text from your prompt_file
```
The parameters have the same meaning, but since we use dual sockets we set cpu_infer to 65.
### V0.3 Showcase
#### Dual socket version (64 cores)
Our local_chat test command is:
``` shell
wget https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl
pip install ./ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl
python -m ktransformers.local_chat --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
# when you see "Chat:", press enter to load the text from your prompt_file
```
The parameters have the same meaning as in V0.2, but since we use dual sockets we set cpu_infer to 65.
## Some Explanations
1. We also want to make further use of the two NUMA nodes on the Xeon Gold CPU.
To avoid the cost of data transfer between nodes, we "copy" the critical matrices onto
both nodes, which consumes more memory but accelerates the prefill and decoding process.
However, this method uses a lot of memory and is slow when loading weights, so be patient when loading
and monitor the memory usage. We are going to optimize this memory overhead. Stay tuned~ <br>
2. The command argument `--cpu_infer 65` specifies how many cores to use (it is fine if it exceeds the physical number,
but more is not always better; adjust it to roughly your actual number of cores, as in the sizing sketch after this list).<br>
3. Why CPU/GPU Hybrid Inference?
DeepSeek's MLA operators are highly computationally intensive. While running everything on CPU is possible, offloading the heavy computations to the GPU results in a massive performance boost.
4. Where Does the Speedup Come From?
- Expert Offload: Unlike traditional layer-based or KVCache offloading (as seen in llama.cpp), we offload the expert computation to the CPU and MLA/KVCache to GPU, aligning perfectly with DeepSeek’s architecture for optimal efficiency.
- Intel AMX Optimization – Our AMX-accelerated kernel is meticulously tuned, running several times faster than existing llama.cpp implementations. We plan to open-source this kernel after cleanup and are considering upstream contributions to llama.cpp.
5. Why Intel CPUs?
Intel is currently the only CPU vendor that supports AMX-like instructions, which deliver significantly better performance than AVX-only alternatives.
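As a small sizing sketch for point 2 (an illustration under stated assumptions, not project code), `--cpu_infer` can be derived from the physical core count, mirroring the 33 and 65 values used in the commands above:

``` python
# Sketch: pick a --cpu_infer value from the physical core count
# (the tutorial uses 33 for one 32-core socket and 65 for 2 x 32 cores).
import os

physical_cores = 64   # assumption: set to your machine's physical core count;
                      # os.cpu_count() reports logical cores (2x with hyper-threading)
cpu_infer = physical_cores + 1
print(f"--cpu_infer {cpu_infer}   (logical cores reported: {os.cpu_count()})")
```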
\ No newline at end of file
...@@ -5,7 +5,7 @@ Description : ...@@ -5,7 +5,7 @@ Description :
Author : kkk1nak0 Author : kkk1nak0
Date : 2024-08-15 07:34:46 Date : 2024-08-15 07:34:46
Version : 1.0.0 Version : 1.0.0
LastEditors : Azure-Tang LastEditors : unicornchan
LastEditTime : 2024-08-29 22:35:51 LastEditTime : 2025-02-10 00:59:53
''' '''
__version__ = "0.1.4" __version__ = "0.2.0"
\ No newline at end of file \ No newline at end of file
...@@ -54,4 +54,4 @@ long_context: ...@@ -54,4 +54,4 @@ long_context:
token_step: token_step:
local_chat: local_chat:
prompt_file: "./ktransformers/p.txt" prompt_file: ""
\ No newline at end of file \ No newline at end of file
...@@ -230,3 +230,24 @@ elseif(UNIX) ...@@ -230,3 +230,24 @@ elseif(UNIX)
endif() endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif() endif()
# Define the USE_NUMA option
option(USE_NUMA "Enable NUMA support" OFF)
# Check if the USE_NUMA environment variable is set
if(DEFINED ENV{USE_NUMA})
set(USE_NUMA ON)
endif()
if (USE_NUMA)
message(STATUS "NUMA support is enabled")
else()
message(STATUS "NUMA support is disabled")
endif()
find_library(NUMA_LIBRARY NAMES numa)
if (NUMA_LIBRARY AND USE_NUMA)
message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support")
target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})
target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)
else()
message(STATUS "NUMA library not found or USE_NUMA not set - disabling NUMA support")
endif()
...@@ -10,6 +10,13 @@ ...@@ -10,6 +10,13 @@
#include "backend.h" #include "backend.h"
#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>
thread_local int Backend::numa_node = -1;
#endif
thread_local int Backend::thread_local_id = -1; thread_local int Backend::thread_local_id = -1;
Backend::Backend(int max_thread_num) { Backend::Backend(int max_thread_num) {
...@@ -74,6 +81,16 @@ void Backend::do_work_stealing_job(int task_num, ...@@ -74,6 +81,16 @@ void Backend::do_work_stealing_job(int task_num,
} }
void Backend::process_tasks(int thread_id) { void Backend::process_tasks(int thread_id) {
#ifdef USE_NUMA
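    // First task on this thread: partition threads evenly across the configured NUMA nodes,
    // then bind this thread's execution and memory allocation to its node so it reads the
    // expert-weight copy kept in node-local DRAM.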
if(numa_node == -1){
numa_node = thread_id * numa_num_configured_nodes() / thread_num_;
struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());
numa_bitmask_setbit(mask, numa_node);
numa_bind(mask);
}
#endif
if (init_func_ != nullptr) { if (init_func_ != nullptr) {
init_func_(thread_id); init_func_(thread_id);
} }
......
...@@ -38,6 +38,9 @@ class Backend { ...@@ -38,6 +38,9 @@ class Backend {
void do_work_stealing_job(int, std::function<void(int)>, void do_work_stealing_job(int, std::function<void(int)>,
std::function<void(int)>, std::function<void(int)>,
std::function<void(int)>); std::function<void(int)>);
#ifdef USE_NUMA
static thread_local int numa_node;
#endif
static thread_local int thread_local_id; static thread_local int thread_local_id;
private: private:
......
...@@ -11,11 +11,41 @@ ...@@ -11,11 +11,41 @@
#include <iostream> #include <iostream>
#include <cstdint> #include <cstdint>
#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>
#endif
MOE::MOE(MOEConfig config) { MOE::MOE(MOEConfig config) {
config_ = config; config_ = config;
gate_proj_ = config_.gate_proj; gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj; up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj; down_proj_ = config_.down_proj;
#ifdef USE_NUMA
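    // Replicate the full gate/up/down expert weights onto every NUMA node
    // (numa_alloc_onnode + memcpy), trading DRAM capacity for local memory bandwidth;
    // threads bound to a node in Backend::process_tasks then read their local copy.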
int numa_nodes = numa_num_configured_nodes();
gate_proj_numa_.resize(numa_nodes);
up_proj_numa_.resize(numa_nodes);
down_proj_numa_.resize(numa_nodes);
size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size;
for (int i = 0; i < numa_nodes; i++) {
gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i);
up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i);
down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i);
if (!gate_proj_numa_[i]) {
std::cout << "Memory allocation failed for gate_proj_numa_ on node " << i << std::endl;
}
if (!up_proj_numa_[i]) {
std::cout << "Memory allocation failed for up_proj_numa_ on node " << i << std::endl;
}
if (!down_proj_numa_[i]) {
std::cout << "Memory allocation failed for down_proj_numa_ on node " << i << std::endl;
}
memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type));
memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type));
memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type));
}
#endif
std::vector<std::pair<void**, uint64_t>> s_mem_requests; std::vector<std::pair<void**, uint64_t>> s_mem_requests;
s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size}); s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});
...@@ -74,6 +104,15 @@ MOE::MOE(MOEConfig config) { ...@@ -74,6 +104,15 @@ MOE::MOE(MOEConfig config) {
MOE::~MOE() { MOE::~MOE() {
shared_mem_buffer.dealloc(this); shared_mem_buffer.dealloc(this);
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
for (int i = 0; i < numa_nodes; i++) {
numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type));
numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type));
numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type));
}
#endif
} }
void MOE::warm_up(Backend* backend) { void MOE::warm_up(Backend* backend) {
...@@ -125,10 +164,22 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c ...@@ -125,10 +164,22 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
int expert_idx = task_id / nth; int expert_idx = task_id / nth;
uint64_t expert_id = expert_ids[expert_idx]; uint64_t expert_id = expert_ids[expert_idx];
int ith = task_id % nth; int ith = task_id % nth;
#ifdef USE_NUMA
void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#else
void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#endif
float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride; float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
#ifdef USE_NUMA
void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#else
void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#endif
float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride; float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
...@@ -153,7 +204,13 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c ...@@ -153,7 +204,13 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
} }
for (int expert_idx = 0; expert_idx < k; expert_idx++) { for (int expert_idx = 0; expert_idx < k; expert_idx++) {
uint64_t expert_id = expert_ids[expert_idx]; uint64_t expert_id = expert_ids[expert_idx];
#ifdef USE_NUMA
void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#else
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#endif
float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride; float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
...@@ -224,14 +281,26 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* ...@@ -224,14 +281,26 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
int stride = QK_K; int stride = QK_K;
int nth = config_.intermediate_size / stride; int nth = config_.intermediate_size / stride;
backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) { backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {
int expert_idx = task_id / nth; uint64_t expert_idx = task_id / nth;
int ith = task_id % nth; int ith = task_id % nth;
void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx]; void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#else
void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#endif
float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride; float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
void* up_input_ptr = m_local_up_input_ptr_[expert_idx]; void* up_input_ptr = m_local_up_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#else
void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#endif
float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride; float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < m_local_num_[expert_idx]; i++) { for (int i = 0; i < m_local_num_[expert_idx]; i++) {
...@@ -246,10 +315,16 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* ...@@ -246,10 +315,16 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
stride = QK_K; stride = QK_K;
nth = config_.hidden_size / stride; nth = config_.hidden_size / stride;
backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) { backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {
int expert_idx = task_id / nth; uint64_t expert_idx = task_id / nth;
int ith = task_id % nth; int ith = task_id % nth;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx]; void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#else
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#endif
float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride; float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
}, nullptr); }, nullptr);
......
...@@ -61,6 +61,12 @@ class MOE { ...@@ -61,6 +61,12 @@ class MOE {
void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)] void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
#ifdef USE_NUMA
std::vector<void*> gate_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]
std::vector<void*> up_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]
std::vector<void*> down_proj_numa_; // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)]
#endif
float* s_input_fp32_; // [hidden_size] float* s_input_fp32_; // [hidden_size]
uint8_t* s_gate_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)] uint8_t* s_gate_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
uint8_t* s_up_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)] uint8_t* s_up_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
......
# """
# Description :
# Author : Boxin Zhang, Azure-Tang
# Version : 0.1.0
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
# """
# import asyncio
# import os
# import platform
# import sys
# project_dir = os.path.dirname(os.path.dirname(__file__))
# sys.path.insert(0, project_dir)
# from ktransformers.server.args import ArgumentParser
# from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
# from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
# from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
# from ktransformers.models.modeling_llama import LlamaForCausalLM
# from ktransformers.models.modeling_mixtral import MixtralForCausalLM
# from ktransformers.server.config.config import Config
# custom_models = {
# "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
# "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
# "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
# "LlamaForCausalLM": LlamaForCausalLM,
# "MixtralForCausalLM": MixtralForCausalLM,
# }
# ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
# default_optimize_rules = {
# "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
# "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
# "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
# "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
# "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
# }
# def local_chat():
# config = Config()
# arg_parser = ArgumentParser(config)
# # Initialize messages
# arg_parser.parse_args()
# if config.backend_type == "transformers":
# from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
# elif config.backend_type == "exllamav2":
# from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
# elif config.backend_type == "ktransformers":
# from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
# else:
# raise NotImplementedError(f"{config.backend_type} not implemented")
# interface = BackendInterface(config)
# system = platform.system()
# if system == "Windows":
# os.system("cls")
# else:
# os.system("clear")
# # add a history chat content
# his_content = []
# while True:
# content = input("Chat: ")
# if content.startswith('"""'): # prefix """
# # multi lines input
# content = content[3:] + "\n"
# while True:
# line = input("")
# if line.endswith('"""'):
# # end multi lines input
# line = line[:-3] # suffix """
# if line:
# content += line + "\n"
# break
# else:
# content += line + "\n"
# if content == "":
# if not config.prompt_file:
# content = "hi"
# else:
# content = open(config.prompt_file, "r").read()
# print("User: ", content)
# elif os.path.isfile(content):
# content = open(content, "r").read()
# print("User: ", content)
# messages = his_content + [{"role": "user", "content": content}]
# async def async_inference(messages):
# generated = ""
# async for token in interface.inference(messages, "local_chat"):
# generated += token
# return generated
# generated = asyncio.run(async_inference(messages))
# his_content += [
# {"role": "user", "content": content},
# {"role": "assistant", "content": generated},
# ]
# if __name__ == "__main__":
# local_chat()
""" """
Description : Description :
Author : Boxin Zhang, Azure-Tang Author : Boxin Zhang, Azure-Tang
Version : 0.1.0 Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
""" """
import asyncio
import os import os
import platform import platform
import sys import sys
project_dir = os.path.dirname(os.path.dirname(__file__)) project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir) sys.path.insert(0, project_dir)
from ktransformers.server.args import ArgumentParser import torch
import logging
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
custom_models = { custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM, "LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM, "MixtralForCausalLM": MixtralForCausalLM,
} }
ktransformer_rules_dir = os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" ktransformer_rules_dir = (
os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = { default_optimize_rules = {
"DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
"LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
"MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
} }
def local_chat(): def local_chat(
config = Config() model_path: str | None = None,
arg_parser = ArgumentParser(config) optimize_rule_path: str = None,
# Initialize messages gguf_path: str | None = None,
arg_parser.parse_args() max_new_tokens: int = 1000,
if config.backend_type == "transformers": cpu_infer: int = Config().cpu_infer,
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface use_cuda_graph: bool = True,
elif config.backend_type == "exllamav2": prompt_file : str | None = None,
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface mode: str = "normal",
elif config.backend_type == "ktransformers": ):
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if mode == 'long_context':
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else: else:
raise NotImplementedError(f"{config.backend_type} not implemented") torch.set_default_dtype(config.torch_dtype)
interface = BackendInterface(config)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if (
"Qwen2Moe" in config.architectures[0]
): # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
)
if optimize_rule_path is None:
if config.architectures[0] in default_optimize_rules:
print("using default_optimize_rule for", config.architectures[0])
optimize_rule_path = default_optimize_rules[config.architectures[0]]
else:
optimize_rule_path = input(
"please input the path of your rule file(yaml file containing optimize rules):"
)
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
except:
gen_config = GenerationConfig(
max_length=128,
temperature=0.7,
top_p=0.9,
do_sample=True
)
model.generation_config = gen_config
# model.generation_config = GenerationConfig.from_pretrained(model_path)
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.eval()
logging.basicConfig(level=logging.INFO)
system = platform.system() system = platform.system()
if system == "Windows": if system == "Windows":
os.system("cls") os.system("cls")
else: else:
os.system("clear") os.system("clear")
# add a history chat content
his_content = []
while True: while True:
content = input("Chat: ") content = input("Chat: ")
if content.startswith('"""'): # prefix """ if content.startswith('"""'): # prefix """
...@@ -73,27 +251,28 @@ def local_chat(): ...@@ -73,27 +251,28 @@ def local_chat():
break break
else: else:
content += line + "\n" content += line + "\n"
if content == "": if content == "":
if config.prompt_file == None or config.prompt_file == "": if prompt_file != None:
content = "Please write a piece of quicksort code in C++." content = open(prompt_file, "r").read()
else: else:
content = open(config.prompt_file, "r").read() content = "Please write a piece of quicksort code in C++."
elif os.path.isfile(content): elif os.path.isfile(content):
content = open(content, "r").read() content = open(content, "r").read()
messages = his_content + [{"role": "user", "content": content}] messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
async def async_inference(messages): messages, add_generation_prompt=True, return_tensors="pt"
generated = "" )
async for token in interface.inference(messages, "local_chat"): if mode == 'long_context':
generated += token assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
return generated "please change max_seq_len in ~/.ktransformers/config.yaml"
torch.set_default_dtype(
generated = asyncio.run(async_inference(messages)) torch.bfloat16
his_content += [ ) # TODO: Remove this, replace dtype using config
{"role": "user", "content": content}, generated = prefill_and_generate(
{"role": "assistant", "content": generated}, model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode
] )
if __name__ == "__main__": if __name__ == "__main__":
local_chat() fire.Fire(local_chat)
\ No newline at end of file
# coding=utf-8
# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DeepSeekV3 model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DeepSeek-V3.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 129280):
Vocabulary size of the DeepSeek-V3 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`DeepseekV3Model`]
hidden_size (`int`, *optional*, defaults to 7168):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 18432):
Dimension of the MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 2048):
Dimension of the MoE representations.
num_hidden_layers (`int`, *optional*, defaults to 61):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 128):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 128):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
n_shared_experts (`int`, *optional*, defaults to 1):
Number of shared experts.
n_routed_experts (`int`, *optional*, defaults to 256):
Number of routed experts.
routed_scaling_factor (`float`, *optional*, defaults to 2.5):
Scaling factor for routed experts.
kv_lora_rank (`int`, *optional*, defaults to 512):
Rank of the LoRA matrices for key and value projections.
q_lora_rank (`int`, *optional*, defaults to 1536):
Rank of the LoRA matrices for query projections.
qk_rope_head_dim (`int`, *optional*, defaults to 64):
Dimension of the query/key heads that use rotary position embeddings.
v_head_dim (`int`, *optional*, defaults to 128):
Dimension of the value heads.
qk_nope_head_dim (`int`, *optional*, defaults to 128):
Dimension of the query/key heads that don't use rotary position embeddings.
n_group (`int`, *optional*, defaults to 8):
Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to 4):
Number of selected groups for each token (ensuring the selected experts are only within `topk_group` groups).
num_experts_per_tok (`int`, *optional*, defaults to 8):
Number of selected experts, None means dense model.
first_k_dense_replace (`int`, *optional*, defaults to 3):
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
\--k dense layers--/
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the weights of the routed experts.
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
Auxiliary loss weight coefficient.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import DeepseekV3Model, DeepseekV3Config
>>> # Initializing a Deepseek-V3 style configuration
>>> configuration = DeepseekV3Config()
>>> # Initializing a model from the Deepseek-V3 style configuration
>>> model = DeepseekV3Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "deepseek_v3"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `DeepseekV3Model`
base_model_tp_plan = {
"layers.*.gate_proj": "colwise",
"layers.*.up_proj": "colwise",
"layers.*.down_proj": "rowwise",
}
def __init__(
self,
vocab_size=129280,
hidden_size=7168,
intermediate_size=18432,
moe_intermediate_size=2048,
num_hidden_layers=61,
num_attention_heads=128,
num_key_value_heads=128,
n_shared_experts=1,
n_routed_experts=256,
routed_scaling_factor=2.5,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
n_group=8,
topk_group=4,
num_experts_per_tok=8,
first_k_dense_replace=3,
norm_topk_prob=True,
aux_loss_alpha=0.001,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=0,
eos_token_id=1,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.head_dim = qk_rope_head_dim
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.aux_loss_alpha = aux_loss_alpha
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["DeepseekV3Config"]
\ No newline at end of file
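As a quick sanity check on the derived dimensions computed in `__init__` above, the following minimal sketch can be run against the default configuration (assuming the class lives at `ktransformers.models.configuration_deepseek_v3`, which the other files in this PR imply but this hunk does not show):

```python
# Minimal sketch: inspect the derived MLA head dimensions of the default config.
# The import path is an assumption based on the other files touched in this PR.
from ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config

cfg = DeepseekV3Config()
# q_head_dim concatenates the no-RoPE and RoPE parts of each query/key head.
assert cfg.q_head_dim == cfg.qk_nope_head_dim + cfg.qk_rope_head_dim  # 128 + 64 = 192
# head_dim is set to the RoPE part only; the cache code later in this PR relies on that.
print(cfg.q_head_dim, cfg.head_dim, cfg.kv_lora_rank)  # 192 64 512
```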
...@@ -34,9 +34,12 @@ class StaticCache(transformers.StaticCache): ...@@ -34,9 +34,12 @@ class StaticCache(transformers.StaticCache):
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = ( if config.architectures[0] == "DeepseekV3ForCausalLM":
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads self.head_dim = config.qk_rope_head_dim
) else:
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32 self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = ( self.num_key_value_heads = (
...@@ -46,7 +49,7 @@ class StaticCache(transformers.StaticCache): ...@@ -46,7 +49,7 @@ class StaticCache(transformers.StaticCache):
self.key_cache: List[torch.Tensor] = [] self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = []
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
if config.architectures[0] == "DeepseekV2ForCausalLM": if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
# key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim) # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
# value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim) # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
...@@ -132,3 +135,7 @@ class StaticCache(transformers.StaticCache): ...@@ -132,3 +135,7 @@ class StaticCache(transformers.StaticCache):
# In-place ops prevent breaking the static address # In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
def get_max_cache_shape(self) -> int:
"""Returns the maximum sequence length of the cache."""
return self.max_cache_len
\ No newline at end of file
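Restated as a plain function, the `head_dim` selection added above looks like the hedged sketch below (`pick_head_dim` is a hypothetical helper, not part of the patch):

```python
# Hypothetical helper mirroring the branch added to StaticCache.__init__ above.
# For DeepseekV3ForCausalLM the per-head cache width follows qk_rope_head_dim (64 by default)
# rather than hidden_size // num_attention_heads (7168 // 128 = 56), which does not match
# what the MLA attention path stores in the cache.
def pick_head_dim(config) -> int:
    if config.architectures[0] == "DeepseekV3ForCausalLM":
        return config.qk_rope_head_dim
    return getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
```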
...@@ -12,15 +12,21 @@ from ktransformers.models.modeling_llama import ( ...@@ -12,15 +12,21 @@ from ktransformers.models.modeling_llama import (
LlamaLinearScalingRotaryEmbedding, LlamaLinearScalingRotaryEmbedding,
LlamaDynamicNTKScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding,
) )
from ktransformers.models.modeling_deepseek_v3 import (
DeepseekV3RotaryEmbedding
)
from ktransformers.models.modeling_deepseek import ( from ktransformers.models.modeling_deepseek import (
DeepseekV2YarnRotaryEmbedding, DeepseekV2YarnRotaryEmbedding,
DeepseekV2RotaryEmbedding, DeepseekV2RotaryEmbedding,
yarn_get_mscale,
yarn_linear_ramp_mask,
yarn_find_correction_range
) )
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
import torch
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
...@@ -53,6 +59,57 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): ...@@ -53,6 +59,57 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
) )
class RotaryEmbeddingV3(BaseInjectedModule):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def load(self):
self._init(
dim=self.config.qk_rope_head_dim,
max_position_embeddings=self.config.max_position_embeddings,
base=self.config.rope_theta,
device=self.device,
)
def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
# self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding): class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
def __init__( def __init__(
self, self,
...@@ -134,6 +191,137 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding): ...@@ -134,6 +191,137 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
self.orig_module.mscale_all_dim, self.orig_module.mscale_all_dim,
) )
# class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
# def __init__(
# self,
# key: str,
# gguf_loader: GGUFLoader,
# config: PretrainedConfig,
# orig_module: nn.Module,
# # device: str = "cuda",
# generate_device: str = "cuda",
# prefill_device: str = "cuda",
# **kwargs,
# ):
# BaseInjectedModule.__init__(
# self, key, gguf_loader, config, orig_module, generate_device, **kwargs
# )
# self.generate_device = generate_device
# self.prefill_device = prefill_device
# def load(self):
# # TODO support perlayer prefill
# self.orig_module.__init__(
# self.config,
# device=self.generate_device
# )
# return
class YarnRotaryEmbeddingV3(BaseInjectedModule):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self):
kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self._init(
dim=self.config.qk_rope_head_dim,
max_position_embeddings=self.config.max_position_embeddings,
base=self.config.rope_theta,
device=self.device,
scaling_factor=self.config.rope_scaling["factor"],
**kwargs,
)
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self._mscale
sin = emb.sin() * self._mscale
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def _init(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
self.original_max_position_embeddings = original_max_position_embeddings
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = mscale
self.mscale_all_dim = mscale_all_dim
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
freq_extra = 1.0 / (
self.base
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)
freq_inter = 1.0 / (
self.scaling_factor
* self.base
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)
low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.original_max_position_embeddings,
)
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
device=device, dtype=torch.float32
)
self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self._mscale = float(
yarn_get_mscale(self.scaling_factor, self.mscale)
/ yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
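The `_init` above blends an unscaled and a factor-scaled frequency table with a linear ramp mask. The standalone sketch below recomputes that blend with the helpers imported from `modeling_deepseek` at the top of this file; the numeric settings (factor 40, 4K original context, beta_fast 32, beta_slow 1, mscale equal to mscale_all_dim) are the usual DeepSeek-V3 checkpoint values and are assumptions here, not read from any config:

```python
import torch
from ktransformers.models.modeling_deepseek import (
    yarn_find_correction_range, yarn_linear_ramp_mask, yarn_get_mscale,
)

dim, base, factor = 64, 10000.0, 40.0          # qk_rope_head_dim, rope_theta, yarn factor
orig_max_pos, beta_fast, beta_slow = 4096, 32, 1

freq_extra = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freq_inter = 1.0 / (factor * base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

low, high = yarn_find_correction_range(beta_fast, beta_slow, dim, base, orig_max_pos)
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)

# High-frequency dims (mask ~ 1) keep the original base; low-frequency dims (mask ~ 0)
# are divided by the scaling factor, exactly as in _init() above.
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask

# When mscale == mscale_all_dim the cos/sin scale reduces to exactly 1.0.
mscale = float(yarn_get_mscale(factor, 1.0) / yarn_get_mscale(factor, 1.0))
```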
class DynamicNTKScalingRotaryEmbedding( class DynamicNTKScalingRotaryEmbedding(
BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding
......
...@@ -13,6 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config ...@@ -13,6 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
from typing import Optional, Tuple from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.custom_gguf import GGUFLoader
...@@ -20,6 +22,206 @@ import logging ...@@ -20,6 +22,206 @@ import logging
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache from transformers.cache_utils import Cache
logger = logging.getLogger("attention") logger = logging.getLogger("attention")
class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
attn_mask: Optional[torch.Tensor] = None
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
self.softmax_scale = self.q_head_dim ** (-0.5)
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
self.q_absorb.weight.data = q_absorb
self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
self.out_absorb.weight.data = out_absorb
del self.orig_module.kv_b_proj
q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
return q_absorb, out_absorb
def forward_chunck(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_layernorm(compressed_kv)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv_seq_len = k_pe.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
compressed_kv = compressed_kv.unsqueeze(1)
k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
compressed_kv = compressed_kv.squeeze(1)
#if cache_position is not None:
# compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:]
# k_pe = k_pe[:,:,: cache_position[-1] + 1,:]
q_absorb, out_absorb = self.get_absorbed()
q_nope = torch.matmul(q_nope, q_absorb)
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
"""
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
assert attention_mask is not None
"""
if attention_mask is not None:
"""
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
"""
#causal_mask = attention_mask[:, :, :, : kv_seq_len]
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q_pe.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
attn_output = torch.matmul(attn_output, out_absorb.mT)
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
if q_len <= self.chunck_size:
return self.forward_chunck(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
**kwargs
)
assert output_attentions == False, "output_attentions is not supported when using chunked attention"
attn_output = None
attn_weight = None
cur_idx = 0
while cur_idx < q_len:
if attention_mask is not None:
chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
else:
# generate chunk_mask automatically.
self.attn_mask = \
torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
if self.attn_mask is None \
else self.attn_mask
self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \
-1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
[:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
self.attn_mask[:, :, :, :cur_idx] = 0
chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))
cur_output, cur_attn_weight = self.forward_chunck(
hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
chunk_mask,
position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
past_key_value,
output_attentions,
use_cache,
cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
**kwargs
)
cur_idx += self.chunck_size
if attn_output is None:
attn_output = cur_output
attn_weight = cur_attn_weight
else:
attn_output = torch.cat((attn_output, cur_output), dim=-2)
attn_weight = torch.cat((attn_weight, cur_attn_weight), dim=-2)
return attn_output, attn_weight, past_key_value
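For readers unfamiliar with the absorbed-MLA trick used by `get_absorbed()` and `forward_chunck()` above, the following is a hedged, shape-only sketch with random tensors; the softmax scale and the RoPE halves of the query/key are omitted, and the sizes follow the DeepSeek-V3 defaults from the configuration earlier in this PR:

```python
import torch

num_heads, kv_lora_rank = 128, 512
qk_nope_head_dim, v_head_dim = 128, 128

# kv_b_proj maps the compressed KV (kv_lora_rank wide) to per-head [k_nope ; v].
kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
kv_b = kv_b_proj_weight.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
q_absorb = kv_b[:, :qk_nope_head_dim, :]    # (heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = kv_b[:, qk_nope_head_dim:, :]  # (heads, v_head_dim, kv_lora_rank)

# Absorbed attention then works directly in the compressed latent space.
bsz, q_len, kv_len = 1, 4, 16
q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)

q_lat = torch.matmul(q_nope, q_absorb)                       # (bsz, heads, q_len, kv_lora_rank)
scores = torch.matmul(q_lat, compressed_kv.unsqueeze(1).mT)  # (bsz, heads, q_len, kv_len)
ctx = torch.einsum("bhqk,bkc->bhqc", scores.softmax(-1), compressed_kv)
out = torch.matmul(ctx, out_absorb.mT)                       # (bsz, heads, q_len, v_head_dim)
```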
class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
attn_mask: Optional[torch.Tensor] = None attn_mask: Optional[torch.Tensor] = None
......
...@@ -302,13 +302,13 @@ class KExpertsMarlin(KExpertsBase): ...@@ -302,13 +302,13 @@ class KExpertsMarlin(KExpertsBase):
if w is None: w = self.load_weights()[self.key] if w is None: w = self.load_weights()[self.key]
if isinstance(w, dict): if isinstance(w, dict):
self.gate = nn.Parameter(torch.from_numpy(w["gate"])) self.gate = w["gate"]
self.up = nn.Parameter(torch.from_numpy(w["up"])) self.up = (w["up"])
self.down = nn.Parameter(torch.from_numpy(w["down"])) self.down = (w["down"])
for i in range(self.expert_num): for i in range(self.expert_num):
self.up_projs[i].load(self.up[i,...], device=device) self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
self.gate_projs[i].load(self.gate[i,...], device=device) self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
self.down_projs[i].load(self.down[i,...], device=device) self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
self.loaded_experts_idx.append(i) self.loaded_experts_idx.append(i)
return return
...@@ -342,23 +342,45 @@ class KExpertsMarlin(KExpertsBase): ...@@ -342,23 +342,45 @@ class KExpertsMarlin(KExpertsBase):
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"]) # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
return res return res
def forward(self, input_tensor:torch.Tensor, expert_ids, weights): def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
# forward org_dtype = hidden_states_cpu.dtype
device = input_tensor.device org_device = hidden_states_cpu.device
input_tensor = input_tensor.to("cuda") hidden_states_cpu = hidden_states_cpu.to(self.device)
outs = torch.zeros_like(input_tensor) selected_experts_cpu = selected_experts_cpu.to(self.device)
for expert_idx in range(expert_ids.size(0)): routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype)
down_proj = self.down_projs[expert_idx]
gate_proj = self.gate_projs[expert_idx] batch_sequence_length, hidden_dim = hidden_states_cpu.size()
up_proj = self.up_projs[expert_idx]
outs += down_proj(self.act_fn(gate_proj(input_tensor)) * up_proj(input_tensor)) * weights[expert_idx]
outs = outs.to(device)
return outs
final_hidden_states = torch.zeros(
(batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.expert_num):
if not expert_mask[expert_idx].any():
continue
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
G = self.gate_projs[expert_idx].forward(current_state)
A = self.act_fn(G)
U = self.up_projs[expert_idx].forward(current_state)
H = A * U # Element-wise multiplication
current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None]
# However `index_add_` only supports torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states)
return final_hidden_states.to(dtype=org_dtype, device=org_device)
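The rewritten `forward` above follows the standard Mixtral-style dispatch (one-hot expert mask, `torch.where`, `index_add_`). Below is a hedged, self-contained miniature of just the routing bookkeeping, with a scalar multiply standing in for the real expert MLPs:

```python
import torch

num_experts, top_k, tokens, hidden = 4, 2, 3, 8
hidden_states = torch.randn(tokens, hidden)
router_probs = torch.softmax(torch.randn(tokens, num_experts), dim=-1)
routing_weights, selected_experts = torch.topk(router_probs, top_k, dim=-1)

out = torch.zeros(tokens, hidden)
# expert_mask[e, k, t] == 1 iff token t picked expert e as its k-th choice.
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

for e in range(num_experts):
    if not expert_mask[e].any():
        continue
    k_idx, tok_idx = torch.where(expert_mask[e])
    expert_out = hidden_states[tok_idx] * (e + 1)  # stand-in for gate/up/down projections
    out.index_add_(0, tok_idx, expert_out * routing_weights[tok_idx, k_idx, None])
```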
class KExpertsTorch(KExpertsBase): class KExpertsTorch(KExpertsBase):
expert_num: int expert_num: int
loaded_experts_idx: list[int] loaded_experts_idx: list[int]
...@@ -519,6 +541,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase): ...@@ -519,6 +541,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
from ktransformers.models.modeling_deepseek import DeepseekV2MoE from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
...@@ -727,6 +750,107 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE): ...@@ -727,6 +750,107 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
) )
return final_out return final_out
class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
def forward(self, hidden_states):
identity = hidden_states
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
self.moe_infer(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO may bugs here
y = (
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
if self.config.n_shared_experts is not None:
y += y_
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
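`moe_infer` above avoids the per-expert mask by sorting the flattened token-expert assignments once with `argsort`, running each expert on a contiguous slice, and scattering the results back. A hedged miniature of that grouping logic, with identity functions standing in for the experts:

```python
import torch

num_experts, top_k, tokens, hidden = 4, 2, 5, 8
x = torch.randn(tokens, hidden)
scores = torch.softmax(torch.randn(tokens, num_experts), dim=-1)
topk_weight, topk_ids = torch.topk(scores, top_k, dim=-1)

cnts = topk_ids.new_zeros((tokens, num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)      # (num_experts,) slots routed to each expert

idxs = topk_ids.view(-1).argsort()       # flat slot indices, grouped by expert id
sorted_tokens = x[idxs // top_k]         # owning token for each sorted slot

outputs, start = [], 0
for e, n in enumerate(tokens_per_expert.tolist()):
    if n == 0:
        continue
    outputs.append(sorted_tokens[start:start + n])  # identity stand-in for expert e
    start += n
outs = torch.cat(outputs, dim=0)

new_x = torch.empty_like(outs)
new_x[idxs] = outs                       # scatter back to (token, slot) order
final = (new_x.view(tokens, top_k, hidden) * topk_weight.unsqueeze(-1)).sum(dim=1)
```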
class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock): class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......