Commit a5f106eb authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-marlin' into 'v0.9.2-dev'

fix: 优化w4a8 marlin 中 weight重排耗时

See merge request dcutoolkit/deeplearing/vllm!200
parents bc9aee38 ffab74dd
......@@ -54,12 +54,12 @@ def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(np.uint32)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
q_w = q_w.contiguous().to(torch.int32)
M, N = q_w.shape
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << 4 * i
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
q_packed += q_w[:, i::pack_factor] << (4 * i)
return q_packed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment