OpenDAS / ktransformers / Commits / 91c16192

Commit 91c16192, authored Feb 25, 2025 by Azure

    Merge branch 'develop-0.2.2' into support-fp8

    Update README.md

Parents: 2c0cce90, d9b2895b

Changes: 26 · Showing 6 changed files with 35 additions and 8 deletions (+35 / -8)
Files changed:

    ktransformers/server/backend/interfaces/ktransformers.py   +5   -0
    ktransformers/server/backend/interfaces/transformers.py    +1   -1
    ktransformers/tests/mmlu_pro_test.py                        +2   -2
    ktransformers/util/custom_gguf.py                           +2   -0
    ktransformers/util/utils.py                                 +24  -5
    setup.py                                                    +1   -0
ktransformers/server/backend/interfaces/ktransformers.py

```diff
@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

 warm_uped = False
@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
         input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
         torch.cuda.set_device(device)
+        if flashinfer_enabled:
+            MLAWrapperSingleton.need_plan_all()
         if self.use_static_cache:
             logits = self.model(
                 inputs_embeds=inputs_embeds,
@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
```
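The two additions bracket the prefill forward pass: need_plan_all() asks every MLA attention wrapper to re-plan for the incoming prompt, and reset_buffer() drops the prefill-sized scratch state once the first token has been produced, so the decode loop starts from a clean plan. The stand-in class below only illustrates that call order; it is not the real ktransformers.operators.flashinfer_wrapper API.

```python
# Minimal sketch of the plan/reset lifecycle wired in above. FakeMLAWrapper and
# its fields are hypothetical; only the call order mirrors the diff.
class FakeMLAWrapper:
    def __init__(self):
        self.plan_pending = False   # set before a forward pass that needs new plans
        self.buffer_pages = 0       # stands in for prefill-sized scratch buffers

    def need_plan_all(self):
        self.plan_pending = True    # called right before the prefill forward

    def reset_buffer(self):
        self.buffer_pages = 0       # called right after the first token is sampled


wrapper = FakeMLAWrapper()
wrapper.need_plan_all()             # 1. plan for the full prompt
wrapper.buffer_pages = 128          #    (pretend the prefill forward filled buffers)
wrapper.plan_pending = False
wrapper.reset_buffer()              # 2. clean up before the decode loop re-plans per step
assert wrapper.buffer_pages == 0
```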
ktransformers/server/backend/interfaces/transformers.py

```diff
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.args.max_new_tokens):
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                if i > 1 and flashinfer_enabled:
+                if flashinfer_enabled:
                     MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1,
                                                  num_heads=self.model.config.num_attention_heads,
                                                  head_dim_ckv=self.model.config.kv_lora_rank,
                                                  head_dim_kpe=self.model.config.qk_rope_head_dim,
                                                  page_size=self.cache.page_size,
```
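Dropping the i > 1 guard is consistent with the previous file now calling reset_buffer() right after prefill: once the buffer is reset, the decode loop has to re-plan from its very first step. For readers mapping the keyword arguments, the sketch below uses DeepSeek-V3-style MLA dimensions; the numbers are illustrative of that model family, not read from this repository's config.

```python
# Illustrative mapping from a DeepSeek-V3-style config to plan_all's arguments.
config = {
    "num_attention_heads": 128,  # -> num_heads
    "kv_lora_rank": 512,         # -> head_dim_ckv (compressed/latent KV dimension)
    "qk_rope_head_dim": 64,      # -> head_dim_kpe (rotary positional key dimension)
}
page_size = 64                   # -> page_size, taken from the paged KV cache

print(f"plan_all(..., num_heads={config['num_attention_heads']}, "
      f"head_dim_ckv={config['kv_lora_rank']}, "
      f"head_dim_kpe={config['qk_rope_head_dim']}, page_size={page_size}, ...)")
```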
ktransformers/tests/mmlu_pro_test.py

```diff
@@ -173,8 +173,8 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="API Generate Tester")
     parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations")
     parser.add_argument("--file", type=str, default="TIGER-Lab/MMLU-Pro", help="Path to the mmlu.jsonl file")
-    parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
-    parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
+    parser.add_argument("--result", type=str, default="./mmlu_result_pro.json", help="Path to save the result JSON file")
+    parser.add_argument("--log", type=str, default="./mmlu_result_pro.log", help="Path to save the log file")
     parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
     parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
     # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
```
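Since the test script only needs an OpenAI-compatible chat-completions endpoint, a quick way to sanity-check the default --api_url before launching a 1000-way concurrent run is a single request. This is a hedged sketch that assumes a ktransformers server is already listening on localhost:15488 and that the requests package is installed.

```python
import requests

# One-off smoke test against the script's default endpoint and model name.
payload = {
    "model": "Pro/deepseek-ai/DeepSeek-V3",
    "messages": [{"role": "user", "content": "Reply with the single letter A."}],
    "temperature": 0.0,
}
resp = requests.post("http://localhost:15488/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```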
ktransformers/util/custom_gguf.py

```diff
@@ -330,6 +330,8 @@ class GGUFLoader:
         values = GGML_DEQUANTIZE[ggml_name](data)
         values = torch.from_numpy(values.copy())
+        if ggml_name == "BF16":
+            values = values.view(torch.bfloat16)
         values = values.view(shape[-2::-1])
         return values
```
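The extra branch is needed because numpy has no bfloat16 dtype: a BF16 GGUF tensor reaches Python as raw 16-bit data, and view(torch.bfloat16) reinterprets those bits without copying. A small self-contained round-trip illustrating the idea (not the GGUF loading path itself):

```python
import torch

# Round-trip: bfloat16 values -> raw 16-bit integers (what a numpy-backed loader
# effectively sees) -> torch tensor -> bit-level reinterpretation back to bfloat16.
original = torch.tensor([1.0, -2.5, 3.75], dtype=torch.bfloat16)
raw_bits = original.view(torch.int16).numpy()             # same bytes, integer view
restored = torch.from_numpy(raw_bits.copy()).view(torch.bfloat16)
assert torch.equal(original, restored)
print(restored)
```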
ktransformers/util/utils.py

```diff
@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 warm_uped = False

+def get_compute_capability(device: torch.device = None):
+    if torch.cuda.is_available():
+        if device is None:
+            num_gpus = torch.cuda.device_count()
+            min_compute_capability_major = 100
+            for gpu_id in range(num_gpus):
+                gpu_props = torch.cuda.get_device_properties(gpu_id)
+                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
+            return min_compute_capability_major
+        else:
+            return torch.cuda.get_device_properties(device)
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
```
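A plausible use of the new helper, sketched under the assumption that callers import it from ktransformers.util.utils and only need the no-argument form (with an explicit device, the function as written returns the full device-properties object rather than just the major version):

```python
import torch
from ktransformers.util.utils import get_compute_capability

# Gate compute-capability-dependent kernels (e.g. FP8 paths) on the weakest
# visible GPU; the >= 9 threshold below is illustrative, not taken from the repo.
if torch.cuda.is_available():
    weakest_major = get_compute_capability()   # min major version across all GPUs
    use_fp8_kernels = weakest_major >= 9
else:
    use_fp8_kernels = False
print(f"FP8 kernels enabled: {use_fp8_kernels}")
```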
```diff
@@ -164,6 +176,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
         else:
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+        if use_flashinfer_mla:
+            MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
+            MLAWrapperSingleton.need_plan_all()
         logits = model(inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0][:, -1, :].unsqueeze(0).clone().to(torch_device)
@@ -186,6 +202,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         first_token_time = time.time() - start_time
+        if use_flashinfer_mla:
+            MLAWrapperSingleton.reset_buffer()
         prefill_count = seq_length
         prefill_time = first_token_time
@@ -203,22 +222,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     start_time = time.time()
     for i in range(1, max_new_tokens):
+        if use_flashinfer_mla:
+            MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1,
+                                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         global warm_uped
         if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
             warm_uped = True
             cuda_graph_runner = CUDAGraphRunner()
             cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-        if i > 1 and use_flashinfer_mla:
-            MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1,
-                                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
         generated_ids[:, cache_position] = next_token.int()
         tokens.append(int(next_token))
         seq_length += 1
-        if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
+        if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
             print(stream.end(), end="", flush=True)
             break
         else:
```
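The warm-up flag in the last hunk controls when the CUDA graph is captured: if an earlier generation already warmed the kernels (warm_uped is True) the graph is captured at decode step 1, otherwise step 1 runs eagerly as the warm-up and capture happens at step 2. A stand-alone sketch of just that scheduling logic (names are illustrative, and the actual CUDAGraphRunner().capture(...) call is replaced by a boolean):

```python
# Reproduces only the capture-scheduling condition from the decode loop above.
warm_uped = False

def should_capture(i: int, use_cuda_graph: bool = True) -> bool:
    global warm_uped
    if use_cuda_graph and ((warm_uped and i == 1) or (not warm_uped and i == 2)):
        warm_uped = True          # the real loop captures the CUDA graph here
        return True
    return False

print([i for i in range(1, 6) if should_capture(i)])  # first generation: capture at step 2
print([i for i in range(1, 6) if should_capture(i)])  # later generations: capture at step 1
```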
setup.py

```diff
@@ -350,6 +350,7 @@ elif MUSA_HOME is not None:
             "at::cuda": "at::musa",
             "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
             "#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
+            "nv_bfloat16": "mt_bfloat16",
         }).run()
         ops_module = MUSAExtension('KTransformersOps', [
             'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
```
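The MUSA build path works by textually porting the CUDA sources before compiling them with MUSAExtension; the new entry maps the CUDA nv_bfloat16 type to its MUSA counterpart. The helper below is an illustrative re-implementation of that substitution idea, not the actual porting utility invoked in setup.py:

```python
# Hypothetical sketch of the CUDA -> MUSA source substitution applied at build time.
CUDA_TO_MUSA = {
    "at::cuda": "at::musa",
    "#include <ATen/cuda/CUDAContext.h>": '#include "torch_musa/csrc/aten/musa/MUSAContext.h"',
    "#include <c10/cuda/CUDAGuard.h>": '#include "torch_musa/csrc/core/MUSAGuard.h"',
    "nv_bfloat16": "mt_bfloat16",   # the mapping this commit adds
}

def port_source(text: str) -> str:
    for cuda_token, musa_token in CUDA_TO_MUSA.items():
        text = text.replace(cuda_token, musa_token)
    return text

print(port_source("#include <c10/cuda/CUDAGuard.h>\n__global__ void k(nv_bfloat16 x);"))
```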