Unverified commit fb140c81, authored by senlyu163, committed by GitHub
Browse files

feat: add rank calculation for transformer block state dict in merge_safetensors (#663)

parent 066dd05f
...@@ -107,6 +107,8 @@ def merge_safetensors( ...@@ -107,6 +107,8 @@ def merge_safetensors(
state_dict = unquantized_part_sd state_dict = unquantized_part_sd
state_dict.update(transformer_block_sd) state_dict.update(transformer_block_sd)
rank = next((v.shape[1] for k, v in transformer_block_sd.items() if ".lora_down" in k), 32)
precision = "int4" precision = "int4"
for v in state_dict.values(): for v in state_dict.values():
assert isinstance(v, torch.Tensor) assert isinstance(v, torch.Tensor)
...@@ -130,6 +132,7 @@ def merge_safetensors( ...@@ -130,6 +132,7 @@ def merge_safetensors(
"scale_dtype": "fp8_e4m3_nan" if precision == "fp4" else None, "scale_dtype": "fp8_e4m3_nan" if precision == "fp4" else None,
"group_size": 16 if precision == "fp4" else 64, "group_size": 16 if precision == "fp4" else 64,
}, },
"rank": rank,
} }
return state_dict, { return state_dict, {
"config": Path(config_path).read_text(), "config": Path(config_path).read_text(),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment