Bw-bestperf / pd-test · Commit e328df95

"Upload New File" — authored Feb 06, 2026 by jerrrrry; parent dd80c5ce
Showing 1 changed file with 189 additions and 0 deletions.

bf16_2_w4a8/bf16_cast_channel_int4_v2.py (new file, mode 100644, +189 −0)
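The uploaded script converts a BF16 DeepSeek-R1 checkpoint to a W4A8 layout: every weight that has a `_scale_inv` entry in the safetensors index is quantized to per-channel symmetric int8, MoE expert weights (`.mlp.experts.`) are further quantized to int4 and packed two values per byte, and the scales, config.json, and model.safetensors.index.json are rewritten to match.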
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm

import torch
from safetensors.torch import load_file, save_file
from huggingface_hub import snapshot_download
import numpy as np
import matplotlib.pyplot as plt
def get_plot(matrix: torch.Tensor):
    """Debug helper: plot a histogram of each row's value distribution."""
    n_rows = matrix.shape[0]
    row_labels = [f"Row_{i}" for i in range(n_rows)]
    # Generate and save a separate image for each row
    for i in range(n_rows):
        plt.figure(figsize=(8, 4))
        plt.hist(matrix[i, :], bins=20, alpha=0.7, color='green')
        plt.title(f"Distribution of {row_labels[i]}")
        plt.xlabel("Value")
        plt.ylabel("Frequency")
        plt.savefig(f"./result/row_{i}_histogram.png")  # save as PNG
        plt.close()
def weight_quant(tensor: torch.Tensor):
    """Per-channel (per-row) symmetric int8 quantization."""
    assert tensor.dim() == 2
    qmax = 127.0  # quantized range: -127 to 127
    abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0]  # [rows, 1]
    scale = abs_max / qmax  # [rows, 1]
    assert scale.shape == (tensor.shape[0], 1)
    quantized = torch.round(tensor / scale)
    quantized = torch.clamp(quantized, -qmax, qmax)
    return quantized.to(torch.int8), scale.to(torch.float32)
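weight_quant keeps a separate scale per output row, so dequantization is simply q * scale. A minimal round-trip check (not part of the committed file; names are illustrative):

# Round-trip check for per-row symmetric int8 quantization.
w = torch.randn(4, 64, dtype=torch.bfloat16)
q, s = weight_quant(w)
w_hat = q.to(torch.float32) * s                     # [rows, cols] * [rows, 1]
print((w.to(torch.float32) - w_hat).abs().max())    # at most ~0.5 * s per row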
def weight_quantint4(tensor: torch.Tensor):
    """Per-row symmetric int4 quantization with 95th-percentile clipping;
    packs two signed 4-bit values into each output byte."""
    assert tensor.dim() == 2
    qmax = 7.0  # quantized range: -7 to 7
    # Sort the absolute values of each row, then clip the scale at the 95th
    # percentile rather than the row maximum, so outliers do not stretch the
    # narrow int4 range.
    abs_value = torch.abs(tensor)
    sorted_matrix, _ = torch.sort(abs_value, dim=1)
    k = tensor.shape[1]
    index = int(k * 0.95)
    abs_max = sorted_matrix[:, index].reshape(-1, 1)
    # abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0]  # max-based alternative
    scale = abs_max / qmax  # [rows, 1]
    assert scale.shape == (tensor.shape[0], 1)
    # Quantize; values beyond the clipped range saturate at +/-7
    quantized = torch.round(tensor / scale)
    quantized = torch.clamp(quantized, -qmax, qmax).to(torch.int8)
    # (Disabled experiment: remapping negatives to an unsigned code via an
    # offset, e.g. quantized[quantized < 0] += 8, plus debug prints.)
    # Pack two signed 4-bit values per byte: even columns go to the high
    # nibble, odd columns to the low nibble. Viewing the tensor as uint8
    # makes the nibbles plain two's-complement bit patterns. Assumes an
    # even number of columns.
    quantized_int8 = quantized.to(torch.uint8)
    n, k = quantized.size()  # packed result is [n, k // 2]
    a = quantized_int8[..., ::2]   # even columns
    b = quantized_int8[..., 1::2]  # odd columns
    a_4bit = a          # no mask needed: the left shift discards the high bits
    b_4bit = b & 0x0F
    quantized_int4 = (a_4bit << 4) | b_4bit
    quantized_int4 = quantized_int4.contiguous().to(torch.int8)
    return quantized_int4, scale.to(torch.float32)
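The commit itself never unpacks these bytes again, so as a sanity check here is a minimal sketch of the inverse transform (unpack_int4 is a hypothetical helper, not part of the file), confirming that the signed values survive the uint8 round trip:

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of the packing in weight_quantint4 (hypothetical helper).

    The high nibble holds the even column, the low nibble the odd column;
    nibbles are two's-complement, so codes >= 8 map back to negatives."""
    u = packed.to(torch.uint8).to(torch.int16)
    hi = (u >> 4) & 0x0F
    lo = u & 0x0F
    hi = torch.where(hi >= 8, hi - 16, hi)
    lo = torch.where(lo >= 8, lo - 16, lo)
    out = torch.empty(packed.shape[0], packed.shape[1] * 2,
                      dtype=torch.int8, device=packed.device)
    out[..., ::2] = hi.to(torch.int8)
    out[..., 1::2] = lo.to(torch.int8)
    return out

w = torch.randn(2, 8)
packed, s4 = weight_quantint4(w)
assert torch.equal(unpack_int4(packed),
                   torch.clamp(torch.round(w / s4), -7, 7).to(torch.int8))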
def main(bf16_path, int8_path, model_name="deepseek-ai/DeepSeek-R1"):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(int8_path, exist_ok=True)
    model_index_file = os.path.join(int8_path, "model.safetensors.index.json")
    config_file = os.path.join(int8_path, "config.json")

    # Fetch config.json and the safetensors index (but no weights) if missing
    if not os.path.exists(model_index_file) or not os.path.exists(config_file):
        snapshot_download(repo_id=model_name,
                          ignore_patterns=["*.safetensors"],
                          local_dir=int8_path,
                          local_dir_use_symlinks=False)
        print(f"model index file and config file downloaded to {int8_path}")

    # Modify config.json and save it: drop the original quantization_config
    with open(config_file) as f:
        config = json.load(f)
    config.pop("quantization_config", None)
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True)
    print(f"config.json modified and saved to {config_file}")

    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]
    scale_count = len([key for key in weight_map.keys()
                       if key.endswith("_scale_inv")])

    safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
    safetensor_files.sort()
    quant_count = 0
    new_weight_map = {}
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        state_dict = load_file(safetensor_file, device="cuda")
        new_state_dict = {}
        for weight_name, weight in state_dict.items():
            scale_inv_name = f"{weight_name}_scale_inv"
            if scale_inv_name in weight_map:
                print("scale_inv_name:", scale_inv_name)
                assert weight.element_size() == 2  # must be a bf16 weight
                quant_count += 1
                int8_weight, scale_inv = weight_quant(weight)
                new_scale_name = scale_inv_name.replace("_scale_inv", "_scale")
                if ".mlp.experts." in weight_name:
                    # MoE expert weights: re-quantize the int8 tensor to packed int4
                    int4_weight, scale_int4 = weight_quantint4(int8_weight)
                    new_state_dict[weight_name] = int4_weight
                    new_state_dict[new_scale_name] = scale_inv * scale_int4 / 16
                else:
                    new_state_dict[weight_name] = int8_weight
                    new_state_dict[new_scale_name] = scale_inv
                new_weight_map[weight_name] = file_name
                new_weight_map[new_scale_name] = file_name
            else:
                print("not quantized:", weight_name)
                new_state_dict[weight_name] = weight
                new_weight_map[weight_name] = file_name
        new_safetensor_file = os.path.join(int8_path, file_name)
        save_file(new_state_dict, new_safetensor_file)
    # assert quant_count == scale_count
    print(f"{quant_count} weights are quantized.")

    # Rewrite model.safetensors.index.json with the new weight/scale names
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    model_index["weight_map"] = new_weight_map
    with open(model_index_file, "w", encoding="utf-8") as f:
        json.dump(model_index, f, indent=2, ensure_ascii=False, sort_keys=True)
    print(f"model.safetensors.index.json modified and saved to {model_index_file}")
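For expert weights the quantization is two-level: w ≈ int8 * scale_inv and int8 ≈ int4 * scale_int4, so w ≈ int4 * (scale_inv * scale_int4). The stored scale carries an extra factor of 1/16; a plausible reading (an assumption, not stated in the commit) is that the consuming kernel extracts each nibble into the high-nibble position of an int8, so every unpacked code reads as 16x its signed 4-bit value and the folded-in 1/16 compensates. A minimal numeric check of the two-level composition, reusing the unpack_int4 helper sketched above:

# Two-level dequantization check: w ~ int4 * scale_int4 * scale_inv.
w = torch.randn(2, 8, dtype=torch.bfloat16)
q8, s8 = weight_quant(w)                 # w  ~ q8 * s8
q4p, s4 = weight_quantint4(q8)           # q8 ~ q4 * s4 (q4 packed in q4p)
q4 = unpack_int4(q4p)                    # hypothetical helper from above
w_hat = q4.to(torch.float32) * s4 * s8   # the file stores s8 * s4 / 16, to be
                                         # applied to 16x-scaled nibble codes
print((w.to(torch.float32) - w_hat).abs().max())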
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-bf16-hf-path", type=str,
                        default="/dataset/llm-models/deepseek-r1/DeepSeek-R1-0528-bf16")
    parser.add_argument("--output-int8-hf-path", type=str,
                        default="/FrameWork/0307/3/modeltrans/DeepSeek-R1-0528-SlimQuant-W4A8")
    parser.add_argument("--model-name", type=str,
                        default="deepseek-ai/DeepSeek-R1")
    args = parser.parse_args()
    main(args.input_bf16_hf_path, args.output_int8_hf_path, args.model_name)
    print("done")
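Typical invocation (the paths shown are the defaults baked into the script):

python bf16_2_w4a8/bf16_cast_channel_int4_v2.py \
    --input-bf16-hf-path /dataset/llm-models/deepseek-r1/DeepSeek-R1-0528-bf16 \
    --output-int8-hf-path /FrameWork/0307/3/modeltrans/DeepSeek-R1-0528-SlimQuant-W4A8

Note that load_file(..., device="cuda") keeps each shard on the GPU during quantization, so a CUDA device with enough memory for one shard is required.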