Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
7a19f3b7
"doc/vscode:/vscode.git/clone" did not exist on "af9472b5180ce46d3bd907b57a00c076f41ac160"
Unverified
Commit
7a19f3b7
authored
Feb 27, 2025
by
wang jiahao
Committed by
GitHub
Feb 27, 2025
Browse files
Merge pull request #721 from kvcache-ai/fix_temperature
fix temperature
parents
85e2cc7b
22df52e9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
16 deletions
+16
-16
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+12
-11
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+4
-5
No files found.
ktransformers/server/backend/interfaces/ktransformers.py
View file @
7a19f3b7
...
...
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
torch
.
set_grad_enabled
(
False
)
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_dir
,
device
=
args
.
device
,
trust_remote_code
=
args
.
trust_remote_code
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
args
.
trust_remote_code
)
# Load the model's shipped generation config; fall back to one built from
# CLI args when the model dir has no generation_config.json.
try:
    generation_config = GenerationConfig.from_pretrained(args.model_dir)
except Exception:  # narrow from bare `except:` — don't swallow KeyboardInterrupt/SystemExit
    generation_config = GenerationConfig(
        max_length=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,  # BUG FIX: was `args.temperature` — top_p must come from the top_p arg
        do_sample=True,
    )
torch
.
set_default_dtype
(
config
.
torch_dtype
)
if
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
config
.
_attn_implementation
=
"flash_attention_2"
...
...
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
" belong to current model):"
)
optimize_and_load_gguf
(
self
.
model
,
optimize_config_path
,
gguf_path
,
config
)
self
.
model
.
generation_config
=
generation_config
self
.
device_map
=
self
.
model
.
gguf_loader
.
tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
self
.
cache
=
StaticCache
(
...
...
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
dtype
=
self
.
model
.
dtype
,
)
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
try
:
self
.
model
.
generation_config
=
GenerationConfig
.
from_pretrained
(
args
.
model_dir
)
except
:
gen_config
=
GenerationConfig
(
max_length
=
128
,
temperature
=
0.7
,
top_p
=
0.9
,
do_sample
=
True
)
self
.
model
.
generation_config
=
gen_config
if
self
.
model
.
generation_config
.
pad_token_id
is
None
:
self
.
model
.
generation_config
.
pad_token_id
=
self
.
model
.
generation_config
.
eos_token_id
self
.
streamer
=
TextStreamer
(
self
.
tokenizer
)
...
...
ktransformers/server/backend/interfaces/transformers.py
View file @
7a19f3b7
...
...
@@ -203,10 +203,10 @@ class TransformersInterface(BackendInterfaceBase):
return
self
.
streamer
.
put
(
new_tokens
)
def
prepare_logits_wrapper
(
self
,
inputs
,
device
,
temperature
:
Optional
[
float
]
=
None
,
top_p
:
Optional
[
float
]
=
None
):
if
temperature
is
None
:
temperature
=
self
.
args
.
temperature
if
temperature
is
None
or
temperature
==
0
:
temperature
=
self
.
model
.
generation_config
.
temperature
if
top_p
is
None
:
top_p
=
self
.
args
.
top_p
top_p
=
self
.
model
.
generation_config
.
top_p
generation_config
,
model_kwargs
=
self
.
model
.
_prepare_generation_config
(
None
,
max_length
=
self
.
args
.
max_new_tokens
,
do_sample
=
True
,
...
...
@@ -216,10 +216,9 @@ class TransformersInterface(BackendInterfaceBase):
repetition_penalty
=
self
.
args
.
repetition_penalty
# change this to modify generate config
)
self
.
inputs
=
inputs
self
.
generation_config
=
generation_config
try
:
# transformers==4.43
self
.
logits_warper
=
(
self
.
model
.
_get_logits_warper
(
generation_config
,
device
=
device
)
self
.
model
.
_get_logits_warper
(
generation_config
,
device
=
device
)
)
except
:
self
.
logits_warper
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment