Commit 5540d53a authored by PanZezhong

Add workspace allocator

parent 967bcb64
@@ -11,7 +11,7 @@ struct JiugeModel;
typedef struct
{
infiniDtype_t dt_logits, dt_norm, dt_mat;
infiniDtype_t dt_logits;
size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc;
float epsilon, theta;
uint32_t end_token;
@@ -20,6 +20,7 @@ typedef struct
typedef struct
{
size_t nlayer;
infiniDtype_t dt_norm, dt_mat;
// [dvoc, d]
const void *input_embd;
// [d]
......
@@ -73,11 +73,15 @@ class LlamaWeightsNaming:
class JiugeMetaFromLlama(JiugeMeta):
def __init__(self, config, infini_dtype):
def __init__(self, config, dtype=torch.float16):
if dtype == torch.float16:
dt_ = DataType.INFINI_DTYPE_F16
elif dtype == torch.float32:
dt_ = DataType.INFINI_DTYPE_F32
else:
dt_ = DataType.INFINI_DTYPE_F16  # unsupported dtype: fall back to FP16
super().__init__(
dt_logits=infini_dtype,
dt_norm=infini_dtype,
dt_mat=infini_dtype,
dt_logits=dt_,
nlayer=config.num_hidden_layers,
d=config.hidden_size,
nh=config.num_attention_heads,
@@ -94,10 +98,11 @@ class JiugeMetaFromLlama(JiugeMeta):
theta=config.rope_theta,
end_token=2,
)
self.torch_dtype_logits = dtype
class JiugeWeightsImpl(JiugeWeights):
def __init__(self, meta, naming, state_dict, ndev=1):
def __init__(self, meta, naming, state_dict, torch_dt_mat=torch.float16, torch_dt_norm=torch.float32, ndev=1):
nlayer = meta.nlayer
nh = meta.nh
nkvh = meta.nkvh
@@ -108,17 +113,30 @@ class JiugeWeightsImpl(JiugeWeights):
assert nh % ndev == 0
assert nkvh % ndev == 0
assert di % ndev == 0
torch_dt_logits = meta.torch_dtype_logits
if torch_dt_mat == torch.float16:
self.dt_mat = DataType.INFINI_DTYPE_F16
elif torch_dt_mat == torch.float32:
self.dt_mat = DataType.INFINI_DTYPE_F32
else:
raise ValueError("Unsupported proj weight data type")
if torch_dt_norm == torch.float16:
self.dt_norm = DataType.INFINI_DTYPE_F16
elif torch_dt_norm == torch.float32:
self.dt_norm = DataType.INFINI_DTYPE_F32
else:
raise ValueError("Unsupported norm weight data type")
self.nlayer = nlayer
self.input_embd_tensor = state_dict[naming.input_embd()]
self.input_embd_tensor = state_dict[naming.input_embd()].to(torch_dt_logits)
self.input_embd = self.input_embd_tensor.data_ptr()
self.output_norm_tensor = state_dict[naming.output_norm()]
self.output_norm_tensor = state_dict[naming.output_norm()].to(torch_dt_norm)
self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[naming.output_embd()]
self.output_embd_tensor = state_dict[naming.output_embd()].to(torch_dt_mat)
self.output_embd = self.output_embd_tensor.data_ptr()
self.attn_norm_tensors = [
state_dict[naming.attn_norm(i)] for i in range(nlayer)
state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
]
self.attn_norm_ptrs = [
self.attn_norm_tensors[i].data_ptr() for i in range(nlayer)
@@ -146,7 +164,7 @@ class JiugeWeightsImpl(JiugeWeights):
_result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
return _result
self.qkv_tensor = [torch.concat(qkv_slices(i)) for i in range(nlayer)]
self.qkv_tensor = [torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer)]
self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)
@@ -172,7 +190,7 @@ class JiugeWeightsImpl(JiugeWeights):
return _result
if naming.attn_q_b(0) in state_dict:
self.qkv_b_tensors = [torch.concat(qkv_b_slices(i)) for i in range(nlayer)]
self.qkv_b_tensors = [torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer)]
self.qkv_b_tensor_ptrs = [
self.qkv_b_tensors[i].data_ptr() for i in range(nlayer)
]
@@ -181,7 +199,7 @@ class JiugeWeightsImpl(JiugeWeights):
self.attn_qkv_b = None
self.attn_o_tensor = [
state_dict[naming.attn_o(i)]
state_dict[naming.attn_o(i)].to(torch_dt_mat)
.reshape([d, ndev, nh // ndev * dh])
.transpose(0, 1)
.contiguous()
@@ -190,7 +208,7 @@ class JiugeWeightsImpl(JiugeWeights):
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs)
self.ffn_norm_tensors = [state_dict[naming.ffn_norm(i)] for i in range(nlayer)]
self.ffn_norm_tensors = [state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer)]
self.ffn_norm_ptrs = [
self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer)
]
@@ -206,12 +224,12 @@ class JiugeWeightsImpl(JiugeWeights):
_result.append(state_dict[naming.up(_i)][_start:_end, :])
return _result
self.gate_up_tensors = [torch.concat(gate_up_slices(i)) for i in range(nlayer)]
self.gate_up_tensors = [torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer)]
self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)
self.ffn_down_tensor = [
state_dict[naming.down(i)]
state_dict[naming.down(i)].to(torch_dt_mat)
.reshape([d, ndev, di // ndev])
.transpose(0, 1)
.contiguous()
@@ -223,23 +241,21 @@ class JiugeWeightsImpl(JiugeWeights):
class JiugeForCauslLM:
def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1):
def load_all_safetensors_from_dir(dir_path_: str, torch_type=torch.float16):
def load_all_safetensors_from_dir(dir_path_: str):
tensors_ = {}
dir_path_ = Path(dir_path_)
for file in sorted(dir_path_.glob("*.safetensors")):
data_ = safetensors.safe_open(file, "pt")
for name_ in data_.keys():
tensors_[name_] = data_.get_tensor(name_).to(torch_type)
tensors_[name_] = data_.get_tensor(name_)
return tensors_
config = transformers.AutoConfig.from_pretrained(
model_dir_path, trust_remote_code=True
)
if "llama" == config.model_type:
model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).to(
torch.float16
)
self.meta = JiugeMetaFromLlama(model.config, DataType.INFINI_DTYPE_F16)
model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).half()
self.meta = JiugeMetaFromLlama(model.config)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev
@@ -247,7 +263,7 @@ class JiugeForCauslLM:
elif "fm9g" == config.model_type:
state_dict = load_all_safetensors_from_dir(model_dir_path)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config, DataType.INFINI_DTYPE_F16)
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev
)
@@ -308,6 +324,7 @@ class JiugeForCauslLM:
break
output_content += output_str
print(output_str, end="", flush=True)
# print(output_tokens[0])
req_pos[0] = req_pos[0] + ntok
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
......
@@ -38,8 +38,6 @@ class DeviceType(ctypes.c_int):
class JiugeMeta(ctypes.Structure):
_fields_ = [
("dt_logits", DataType),
("dt_norm", DataType),
("dt_mat", DataType),
("nlayer", c_size_t),
("d", c_size_t),
("nh", c_size_t),
@@ -58,6 +56,8 @@ class JiugeMeta(ctypes.Structure):
class JiugeWeights(ctypes.Structure):
_fields_ = [
("nlayer", c_size_t),
("dt_norm", DataType),
("dt_mat", DataType),
("input_embd", c_void_p),
("output_norm", c_void_p),
("output_embd", c_void_p),
......
#ifndef ALLOCATOR_HPP
#define ALLOCATOR_HPP
#include "infinicore_infer.h"
class AllocatorBase {
public:
    virtual void *alloc(size_t size) = 0;
    virtual void release(void *ptr) = 0;
    // virtual destructor so implementations can be deleted via the base
    virtual ~AllocatorBase() = default;
};

class WorkspaceAllocator : public AllocatorBase {
private:
    // default-initialize so alloc() and the destructor are safe even when
    // the constructor skips the initial allocation
    void *_memory = nullptr;
    size_t _total_size = 0;
    size_t _used_size = 0;
    size_t _align = 256;

public:
    WorkspaceAllocator(size_t initial_size, size_t align = 256);
    ~WorkspaceAllocator();
    void *alloc(size_t size) override;
    void release(void *ptr) override;
};
#endif
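For contrast with the recycling strategy below, a pass-through implementation of the same interface would allocate and free on every call. A minimal sketch (hypothetical, not part of this commit; RUN_INFINI, infinirtMalloc, and infinirtFree are the runtime calls already used elsewhere in this change, and the include paths are assumed):

#include "allocator.hpp" // assumed path
#include "utils.hpp"     // assumed path, for RUN_INFINI

class PassThroughAllocator : public AllocatorBase {
public:
    void *alloc(size_t size) override {
        void *ptr;
        RUN_INFINI(infinirtMalloc(&ptr, size)); // one device malloc per call
        return ptr;
    }
    void release(void *ptr) override {
        RUN_INFINI(infinirtFree(ptr)); // freed immediately, nothing recycled
    }
};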
#include "../allocator.hpp"
#include "../utils.hpp"
// Round size_ up to the next multiple of align (power-of-two alignments
// only: align - 1 is then a contiguous low-bit mask). constexpr so the
// result can be checked at compile time.
inline constexpr size_t aligned_size(size_t size_, size_t align) {
    return (size_ + align - 1) & ~(align - 1);
}
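// Worked examples of the rounding (compile-time checks added here for
// illustration; they are not part of the original commit and rely on
// aligned_size being constexpr):
static_assert(aligned_size(1, 256) == 256, "rounds up to one alignment unit");
static_assert(aligned_size(256, 256) == 256, "exact multiples are unchanged");
static_assert(aligned_size(300, 256) == 512, "otherwise the next multiple");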
// Thin wrapper over the runtime's device malloc.
inline void *allocate(size_t size_) {
    void *ptr;
    RUN_INFINI(infinirtMalloc(&ptr, size_));
    return ptr;
}
WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) {
    _align = align;
    if (initial_size_ > 0) {
        _total_size = aligned_size(initial_size_, _align);
        _memory = allocate(_total_size);
    }
}
void *WorkspaceAllocator::alloc(size_t new_size) {
    if (_total_size < new_size) {
        // Grow to 1.5x the request to amortize future reallocations.
        // Previous contents are discarded, which is acceptable for
        // scratch workspace.
        if (_total_size != 0) {
            RUN_INFINI(infinirtFree(_memory));
        }
        _total_size = aligned_size(new_size * 3 / 2, _align);
        _memory = allocate(_total_size);
    }
    return _memory;
}
void WorkspaceAllocator::release(void *ptr) {
    // Intentionally a no-op: the buffer is recycled between calls and only
    // freed when it grows or in the destructor.
    (void)ptr;
}
WorkspaceAllocator::~WorkspaceAllocator() {
    if (_memory != nullptr) {
        RUN_INFINI(infinirtFree(_memory));
    }
}
\ No newline at end of file
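A minimal usage sketch (hypothetical driver, not part of the commit; it assumes an initialized infinirt runtime and the relative include path used above):

#include "../allocator.hpp"

int main() {
    WorkspaceAllocator allocator(0);  // empty; allocates lazily on first use
    void *a = allocator.alloc(1000);  // grows to aligned_size(1500, 256) = 1536
    allocator.release(a);             // no-op: the buffer is kept for reuse
    void *b = allocator.alloc(1200);  // 1200 <= 1536, same buffer is returned
    void *c = allocator.alloc(4000);  // too large: regrown to 6144 bytes
    (void)b;
    (void)c;
    return 0;                         // destructor frees the device buffer
}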
@@ -31,7 +31,6 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
b_attn_qkv.push_back(
getAttnQKVBias(meta, weights, layer, idev, ndev));
}
w_attn_out.push_back(
getAttnO(meta, weights, layer, idev, ndev));
w_ffn_norm.push_back(
@@ -42,7 +41,8 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
getFFNDown(meta, weights, layer, idev, ndev));
}
*rsrc = DeviceResource{device,
*rsrc = DeviceResource{
device,
dev_id,
handle,
getInEmbd(meta, weights),
@@ -58,10 +58,12 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
w_ffn_gate_up,
w_ffn_down,
stream,
comm};
comm,
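// start with an empty workspace; it grows lazily on the first alloc()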
std::make_unique<WorkspaceAllocator>(0),
};
}
void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
uint32_t idev, uint32_t ndev,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
@@ -75,6 +77,7 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
auto dh = meta.dh;
auto d = meta.d;
auto dt_logits = meta.dt_logits;
// std::cout << "dt_logits: " <<(int)dt_logits << std::endl;
auto di = meta.di / ndev;
auto dvoc = meta.dvoc;
auto stream = rsrc.stream;
@@ -215,12 +218,14 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
infiniopRandomSampleDescriptor_t desc_sample;
RUN_INFINI(infiniopCreateRandomSampleDescriptor(
rsrc.handle, &desc_sample,
TensorDesc::create(INFINI_DTYPE_U64, {}, {})->get(),
TensorDesc::create(INFINI_DTYPE_U32, {}, {})->get(),
TensorDesc::create(dt_logits, {dvoc}, {1})->get()));
RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc_sample, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// Allocate workspace
RUN_INFINI(infinirtMallocAsync(&workspace, workspace_size, stream));
workspace = rsrc.workspace_allocator->alloc(workspace_size);
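// Note: workspace_size is the maximum requirement across all operator
// descriptors created above, so this single buffer is shared by every op
// issued on the stream; the allocator keeps it across batches and only
// regrows when a later batch needs more.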
// Compute
for (uint32_t layer = 0; layer < nlayer; layer++) {
// 1. Attention
// rms norm
@@ -323,11 +328,13 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
// prob_buf->debug();
RUN_INFINI(infiniopRandomSample(
desc_sample, workspace, workspace_size,
result_buf->data(req),
prob_buf->data(req * dvoc), random_val, topp,
topk, temperature, stream));
// result_buf->debug();
token_offset += seq_len;
}
RUN_INFINI(infinirtStreamSynchronize(stream));
@@ -350,7 +357,6 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
infiniopDestroyRMSNormDescriptor(desc_norm_out);
infiniopDestroyGemmDescriptor(desc_out_embd);
infiniopDestroyRandomSampleDescriptor(desc_sample);
infinirtFree(workspace);
}
__C void
......
@@ -3,6 +3,7 @@
#include "infinicore_infer.h"
#include "../../allocator.hpp"
#include "../../tensor.hpp"
#include <condition_variable>
@@ -23,7 +24,10 @@ struct DeviceResource {
w_ffn_norm, w_ffn_gate_up, w_ffn_down;
// Streams
infinirtStream_t stream;
// Communicator
infinicclComm_t comm;
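// Per-device scratch allocator, reused across inferDeviceBatch calls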
std::unique_ptr<WorkspaceAllocator> workspace_allocator;
};
struct InferState {
......
@@ -12,8 +12,8 @@ __C struct KVCache *createKVCache(const JiugeModel *model) {
auto kcache = std::vector<std::shared_ptr<Tensor>>();
auto vcache = std::vector<std::shared_ptr<Tensor>>();
for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) {
kcache.push_back(std::move(Tensor::buffer(model->meta.dt_mat, shape)));
vcache.push_back(std::move(Tensor::buffer(model->meta.dt_mat, shape)));
kcache.push_back(std::move(Tensor::buffer(model->meta.dt_logits, shape)));
vcache.push_back(std::move(Tensor::buffer(model->meta.dt_logits, shape)));
}
cache->k.push_back(kcache);
cache->v.push_back(vcache);
......
@@ -15,7 +15,7 @@ inline std::shared_ptr<Tensor> getOutNorm(
JiugeMeta const *meta,
JiugeWeights const *w) {
auto shape = std::vector<size_t>({meta->d});
return Tensor::weight((char *)w->output_norm, meta->dt_norm, shape);
return Tensor::weight((char *)w->output_norm, w->dt_norm, shape);
}
inline std::shared_ptr<Tensor> getOutEmbd(
@@ -31,7 +31,7 @@ inline std::shared_ptr<Tensor> getAttnNorm(
JiugeWeights const *w,
size_t layer) {
auto shape = std::vector<size_t>({meta->d});
return Tensor::weight((char *)(w->attn_norm[layer]), meta->dt_norm, shape);
return Tensor::weight((char *)(w->attn_norm[layer]), w->dt_norm, shape);
}
inline std::shared_ptr<Tensor> getAttnQKV(
@@ -42,9 +42,9 @@ inline std::shared_ptr<Tensor> getAttnQKV(
auto nh = meta->nh;
auto dh = meta->dh;
auto d = meta->d;
size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(meta->dt_mat);
size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(w->dt_mat);
auto shape = std::vector<size_t>({(nh + 2 * nkvh) / ndev * dh, d});
return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, meta->dt_mat, shape)
return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
}
@@ -55,9 +55,9 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
auto nkvh = meta->nkvh;
auto nh = meta->nh;
auto dh = meta->dh;
size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * dsize(meta->dt_mat);
size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * dsize(w->dt_mat);
auto shape = std::vector<size_t>({1, (nh + 2 * nkvh) / ndev * dh});
return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, meta->dt_mat, shape);
return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
}
inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
@@ -66,9 +66,9 @@ inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
auto nh = meta->nh;
auto dh = meta->dh;
auto d = meta->d;
size_t offset = idev * d * (nh / ndev * dh) * dsize(meta->dt_mat);
size_t offset = idev * d * (nh / ndev * dh) * dsize(w->dt_mat);
auto shape = std::vector<size_t>({d, nh / ndev * dh});
return Tensor::weight((char *)(w->attn_o[layer]) + offset, meta->dt_mat, shape)
return Tensor::weight((char *)(w->attn_o[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
}
@@ -77,7 +77,7 @@ inline std::shared_ptr<Tensor> getFFNNorm(
JiugeWeights const *w,
size_t layer) {
auto shape = std::vector<size_t>({meta->d});
return Tensor::weight((char *)(w->ffn_norm[layer]), meta->dt_norm, shape);
return Tensor::weight((char *)(w->ffn_norm[layer]), w->dt_norm, shape);
}
inline std::shared_ptr<Tensor> getFFNGateUp(
@@ -86,10 +86,10 @@ inline std::shared_ptr<Tensor> getFFNGateUp(
size_t layer, size_t idev, size_t ndev) {
auto di = meta->di;
auto d = meta->d;
size_t offset = idev * (2 * di / ndev) * d * dsize(meta->dt_mat);
size_t offset = idev * (2 * di / ndev) * d * dsize(w->dt_mat);
auto shape = std::vector<size_t>({2 * di / ndev, d});
return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset,
meta->dt_mat, shape)
w->dt_mat, shape)
->permute({1, 0});
}
@@ -99,21 +99,29 @@ inline std::shared_ptr<Tensor> getFFNDown(
size_t layer, size_t idev, size_t ndev) {
auto di = meta->di;
auto d = meta->d;
size_t offset = idev * d * (di / ndev) * dsize(meta->dt_mat);
size_t offset = idev * d * (di / ndev) * dsize(w->dt_mat);
auto shape = std::vector<size_t>({d, di / ndev});
return Tensor::weight((char *)(w->ffn_down[layer]) + offset, meta->dt_mat, shape)
return Tensor::weight((char *)(w->ffn_down[layer]) + offset, w->dt_mat, shape)
->permute({1, 0});
}
inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
auto half_dh = meta->dh / 2;
uint16_t *table = (uint16_t *)std::malloc(meta->dctx * half_dh * sizeof(uint16_t));
auto unit = dsize(meta->dt_logits);
void *table = std::malloc(meta->dctx * half_dh * unit);
for (size_t i = 0; i < meta->dctx; i++) {
for (size_t j = 0; j < half_dh; j++) {
float _sin = std::sin(
static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
table[i * half_dh + j] = f32_to_f16(_sin);
if (meta->dt_logits == INFINI_DTYPE_F16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin);
} else if (meta->dt_logits == INFINI_DTYPE_F32) {
((float *)table)[i * half_dh + j] = _sin;
} else {
std::cout << "unsupported data type" << std::endl;
exit(1);
}
}
}
auto shape = std::vector<size_t>({meta->dctx, half_dh});
@@ -124,13 +132,21 @@ inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
inline std::shared_ptr<Tensor> getCosTable(JiugeMeta const *meta) {
auto half_dh = meta->dh / 2;
uint16_t *table = (uint16_t *)std::malloc(meta->dctx * half_dh * sizeof(uint16_t));
auto unit = dsize(meta->dt_logits);
void *table = std::malloc(meta->dctx * half_dh * unit);
for (size_t i = 0; i < meta->dctx; i++) {
for (size_t j = 0; j < half_dh; j++) {
float _cos = std::cos(
static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh));
table[i * half_dh + j] = f32_to_f16(_cos);
if (meta->dt_logits == INFINI_DTYPE_F16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos);
} else if (meta->dt_logits == INFINI_DTYPE_F32) {
((float *)table)[i * half_dh + j] = _cos;
} else {
std::cout << "unsupported data type" << std::endl;
exit(1);
}
}
}
auto shape = std::vector<size_t>({meta->dctx, half_dh});
......
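The sin and cos table builders above differ only in the trig function applied, so the fill loop could be shared. A possible deduplication (hypothetical sketch, not part of this commit; it assumes dsize, f32_to_f16, and the INFINI_DTYPE_* constants from the surrounding headers):

#include <cmath>
#include <cstdint>
#include <cstdlib>

template <typename TrigFn>
inline void *makeRopeTable(JiugeMeta const *meta, TrigFn fn) {
    auto half_dh = meta->dh / 2;
    void *table = std::malloc(meta->dctx * half_dh * dsize(meta->dt_logits));
    for (size_t i = 0; i < meta->dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float v = fn(static_cast<float>(i) /
                         std::pow(meta->theta, static_cast<float>(j) / half_dh));
            if (meta->dt_logits == INFINI_DTYPE_F16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(v);
            } else { // INFINI_DTYPE_F32, the only other supported type
                ((float *)table)[i * half_dh + j] = v;
            }
        }
    }
    return table;
}
// Usage: makeRopeTable(meta, [](float x) { return std::sin(x); });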
@@ -12,6 +12,7 @@ target("infinicore_infer")
add_files("src/models/*/*.cpp")
add_files("src/tensor/*.cpp")
add_files("src/allocator/*.cpp")
add_includedirs("include")
set_installdir(INFINI_ROOT)
......