Commit 392f0c26 authored by PanZezhong's avatar PanZezhong
Browse files

support transposed weights config

parent 29a8da83
......@@ -21,6 +21,8 @@ typedef struct
{
size_t nlayer;
infiniDtype_t dt_norm, dt_mat;
// 0 if linear weights are passed as W; any non-zero value if passed as W^T (the default layout in PyTorch)
int transpose_linear_weights;
// [dvoc, d]
const void *input_embd;
// [d]
......
......@@ -115,6 +115,7 @@ class JiugeWeightsImpl(JiugeWeights):
torch_dt_mat=torch.float16,
torch_dt_norm=torch.float32,
ndev=1,
transpose_weight=True,
):
nlayer = meta.nlayer
nh = meta.nh
......@@ -150,13 +151,15 @@ class JiugeWeightsImpl(JiugeWeights):
if naming.output_embd() in state_dict
else naming.input_embd()
)
self.transpose_linear_weights = 1 if transpose_weight else 0
self.nlayer = nlayer
self.input_embd_tensor = state_dict[input_embd_naming].to(torch_dt_logits)
self.input_embd = self.input_embd_tensor.data_ptr()
self.output_norm_tensor = state_dict[naming.output_norm()].to(torch_dt_norm)
self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
if not transpose_weight:
self.output_embd_tensor = self.output_embd_tensor.transpose(0, 1).contiguous()
self.output_embd = self.output_embd_tensor.data_ptr()
self.attn_norm_tensors = [
......@@ -191,6 +194,9 @@ class JiugeWeightsImpl(JiugeWeights):
self.qkv_tensor = [
torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer)
]
if not transpose_weight:
for i in range(nlayer):
self.qkv_tensor[i] = self.qkv_tensor[i].reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d).transpose(1, 2).contiguous()
self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)
......@@ -228,10 +234,12 @@ class JiugeWeightsImpl(JiugeWeights):
self.attn_o_tensor = [
state_dict[naming.attn_o(i)]
.to(torch_dt_mat)
.reshape([d, ndev, nh // ndev * dh])
.transpose(0, 1)
.contiguous()
.to(torch_dt_mat)
.reshape([d, ndev, nh // ndev * dh])
.transpose(0, 1)
.contiguous()
if transpose_weight
else state_dict[naming.attn_o(i)].transpose(0, 1).to(torch_dt_mat).contiguous()
for i in range(nlayer)
]
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
......@@ -258,6 +266,9 @@ class JiugeWeightsImpl(JiugeWeights):
self.gate_up_tensors = [
torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer)
]
if not transpose_weight:
for i in range(nlayer):
self.gate_up_tensors[i] = self.gate_up_tensors[i].reshape(ndev, 2 * di // ndev, d).transpose(1, 2).contiguous()
self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)
......@@ -267,6 +278,8 @@ class JiugeWeightsImpl(JiugeWeights):
.reshape([d, ndev, di // ndev])
.transpose(0, 1)
.contiguous()
if transpose_weight
else state_dict[naming.down(i)].transpose(0, 1).to(torch_dt_mat).contiguous()
for i in range(nlayer)
]
self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
......@@ -292,12 +305,13 @@ class JiugeForCauslLM:
self.config = config
eos_token_id = self.config["eos_token_id"]
self.eos_token_id = [eos_token_id] if type(eos_token_id) == int else eos_token_id
transpose_weight = device != DeviceType.DEVICE_TYPE_ASCEND # y = xW is faster than y=xW^T on Ascend
if "llama" == config["model_type"]:
model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).cpu().half()
self.meta = JiugeMetaFromLlama(config)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev, transpose_weight=transpose_weight
)
elif "fm9g" == config["model_type"]:
if any(file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()):
......@@ -309,7 +323,7 @@ class JiugeForCauslLM:
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
......@@ -323,7 +337,7 @@ class JiugeForCauslLM:
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
......@@ -335,7 +349,7 @@ class JiugeForCauslLM:
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path
......
......@@ -58,6 +58,7 @@ class JiugeWeights(ctypes.Structure):
("nlayer", c_size_t),
("dt_norm", DataType),
("dt_mat", DataType),
("transpose_linear_weights", c_int),
("input_embd", c_void_p),
("output_norm", c_void_p),
("output_embd", c_void_p),
......
......@@ -21,9 +21,14 @@ inline std::shared_ptr<Tensor> getOutNorm(
/// Builds the output-embedding (lm-head) weight tensor for the model.
///
/// The underlying buffer layout depends on `w->transpose_linear_weights`:
/// non-zero means linear weights were passed as W^T, i.e. [dvoc, d]
/// (the default layout of PyTorch's nn.Linear), so the tensor is created
/// with that shape and logically transposed via permute; zero means the
/// buffer already holds W in [d, dvoc] and is used as-is.
///
/// @param meta  model metadata (provides d and dvoc)
/// @param w     weight table holding the raw output_embd pointer
/// @return      a [d, dvoc]-shaped weight tensor view over w->output_embd
inline std::shared_ptr<Tensor> getOutEmbd(
    JiugeMeta const *meta,
    JiugeWeights const *w) {
    if (w->transpose_linear_weights != 0) {
        // Buffer holds W^T ([dvoc, d]); permute yields a [d, dvoc] view
        // without copying the data.
        auto shape = std::vector<size_t>({meta->dvoc, meta->d});
        return Tensor::weight((char *)w->output_embd, meta->dt_logits, shape)
            ->permute({1, 0});
    } else {
        // Buffer already holds W in [d, dvoc]; no permutation needed.
        auto shape = std::vector<size_t>({meta->d, meta->dvoc});
        return Tensor::weight((char *)w->output_embd, meta->dt_logits, shape);
    }
}
inline std::shared_ptr<Tensor> getAttnNorm(
......@@ -43,9 +48,14 @@ inline std::shared_ptr<Tensor> getAttnQKV(
auto dh = meta->dh;
auto d = meta->d;
size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(w->dt_mat);
auto shape = std::vector<size_t>({(nh + 2 * nkvh) / ndev * dh, d});
return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
if (w->transpose_linear_weights != 0) {
auto shape = std::vector<size_t>({(nh + 2 * nkvh) / ndev * dh, d});
return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
} else {
auto shape = std::vector<size_t>({d, (nh + 2 * nkvh) / ndev * dh});
return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape);
}
}
inline std::shared_ptr<Tensor> getAttnQKVBias(
......@@ -67,9 +77,14 @@ inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
auto dh = meta->dh;
auto d = meta->d;
size_t offset = idev * d * (nh / ndev * dh) * dsize(w->dt_mat);
auto shape = std::vector<size_t>({d, nh / ndev * dh});
return Tensor::weight((char *)(w->attn_o[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
if (w->transpose_linear_weights != 0) {
auto shape = std::vector<size_t>({d, nh / ndev * dh});
return Tensor::weight((char *)(w->attn_o[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
} else {
auto shape = std::vector<size_t>({nh / ndev * dh, d});
return Tensor::weight((char *)(w->attn_o[layer]) + offset, w->dt_mat, shape);
}
}
inline std::shared_ptr<Tensor> getFFNNorm(
......@@ -87,10 +102,16 @@ inline std::shared_ptr<Tensor> getFFNGateUp(
auto di = meta->di;
auto d = meta->d;
size_t offset = idev * (2 * di / ndev) * d * dsize(w->dt_mat);
auto shape = std::vector<size_t>({2 * di / ndev, d});
return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset,
w->dt_mat, shape)
->permute({1, 0});
if (w->transpose_linear_weights != 0) {
auto shape = std::vector<size_t>({2 * di / ndev, d});
return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset,
w->dt_mat, shape)
->permute({1, 0});
} else {
auto shape = std::vector<size_t>({d, 2 * di / ndev});
return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset,
w->dt_mat, shape);
}
}
inline std::shared_ptr<Tensor> getFFNDown(
......@@ -100,9 +121,14 @@ inline std::shared_ptr<Tensor> getFFNDown(
auto di = meta->di;
auto d = meta->d;
size_t offset = idev * d * (di / ndev) * dsize(w->dt_mat);
auto shape = std::vector<size_t>({d, di / ndev});
return Tensor::weight((char *)(w->ffn_down[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
if (w->transpose_linear_weights != 0) {
auto shape = std::vector<size_t>({d, di / ndev});
return Tensor::weight((char *)(w->ffn_down[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
} else {
auto shape = std::vector<size_t>({di / ndev, d});
return Tensor::weight((char *)(w->ffn_down[layer]) + offset, w->dt_mat, shape);
}
}
inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment