Commit 13f98ed3 authored by PanZezhong's avatar PanZezhong
Browse files

fix workspace

parent 366d3aef
......@@ -13,6 +13,7 @@ from libinfinicore_infer import (
DataType,
DeviceType,
create_jiuge_model,
destroy_jiuge_model,
create_kv_cache,
drop_kv_cache,
infer_batch,
......@@ -282,6 +283,9 @@ class JiugeForCauslLM:
tensors_[name_] = data_.get_tensor(name_)
return tensors_
print("Loading model weights to host...")
load_start_time = time.time()
with open(os.path.join(model_dir_path, "config.json"), "r") as f:
config = json.load(f)
......@@ -293,7 +297,12 @@ class JiugeForCauslLM:
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev
)
elif "fm9g" == config["model_type"]:
if any(file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()):
state_dict = load_all_safetensors_from_dir(model_dir_path)
else:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu"
)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
......@@ -317,6 +326,12 @@ class JiugeForCauslLM:
else:
raise ValueError("Unsupported model architecture")
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
print(f"Creating model on {ndev} devices...")
load_start_time = time.time()
dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
self.model_instance = create_jiuge_model(
byref(self.meta),
......@@ -325,18 +340,21 @@ class JiugeForCauslLM:
ndev,
dev_ids,
)
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
def infer(self, input_list, topp=1.0, topk=1, temperature=1.0):
    """Batch-inference entry point.

    Placeholder — not implemented yet; `generate` carries the actual
    single-request decode loop.  Parameters mirror the sampling knobs
    passed to the native `infer_batch` call (top-p, top-k, temperature).
    """
    pass
def generate(self, input_content, max_steps, topp=1.0, topk=1, temperature=1.0):
kv_cache = create_kv_cache(self.model_instance)
input_content = self.tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": input_content}],
add_generation_prompt=True,
tokenize=False,
)
print(input_content, end="", flush=True)
kv_cache = create_kv_cache(self.model_instance)
tokens = self.tokenizer.encode(input_content)
ntok = len(tokens)
nreq = 1
......@@ -367,6 +385,7 @@ class JiugeForCauslLM:
)
steps += 1
output_tokens = list(ans)
end_time = time.time()
output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ")
......@@ -380,7 +399,7 @@ class JiugeForCauslLM:
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
end_time = time.time()
if step_i > 0:
total_time += end_time - start_time
......@@ -391,6 +410,10 @@ class JiugeForCauslLM:
drop_kv_cache(self.model_instance, kv_cache)
return output_content, avg_time
def destroy_model_instance(self):
    """Release the native model handle created by `create_jiuge_model`.

    Idempotent: the handle is cleared after the first call so a second
    call does not hand a stale pointer to `destroy_jiuge_model`
    (which would be a double-free in the native library).
    """
    if self.model_instance is not None:
        destroy_jiuge_model(self.model_instance)
        self.model_instance = None
        print("Model destroyed")
def test():
if len(sys.argv) < 3:
......@@ -421,6 +444,7 @@ def test():
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
model = JiugeForCauslLM(model_path, device_type, ndev)
model.generate("山东最高的山是?", 500)
model.destroy_model_instance()
if __name__ == "__main__":
......
......@@ -92,12 +92,12 @@ def __open_library__():
c_int, # int ndev
POINTER(c_int), # int const *dev_ids
]
lib.destroyJiugeModel.argtypes = [POINTER(JiugeModel)]
lib.createKVCache.restype = POINTER(KVCache)
lib.dropKVCache.argtypes = [ctypes.POINTER(JiugeModel), POINTER(KVCache)]
lib.dropKVCache.argtypes = [POINTER(JiugeModel), POINTER(KVCache)]
lib.inferBatch.restype = None
lib.inferBatch.argtypes = [
ctypes.POINTER(JiugeModel), # struct JiugeModel const *
POINTER(JiugeModel), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok
POINTER(c_uint), # unsigned int const *req_lens
......@@ -116,6 +116,7 @@ def __open_library__():
LIB = __open_library__()
create_jiuge_model = LIB.createJiugeModel
destroy_jiuge_model = LIB.destroyJiugeModel
create_kv_cache = LIB.createKVCache
drop_kv_cache = LIB.dropKVCache
infer_batch = LIB.inferBatch
......@@ -13,8 +13,7 @@ class WorkspaceAllocator : public AllocatorBase {
private:
void *_memory;
size_t _total_size;
size_t _used_size;
size_t _align = 256;
size_t _align;
public:
WorkspaceAllocator(size_t intial_size, size_t align = 256);
......
......@@ -14,6 +14,8 @@ inline void *allocate(size_t size_) {
WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) {
_align = align;
_total_size = 0;
_memory = nullptr;
if (initial_size_ > 0) {
_total_size = aligned_size(initial_size_, _align);
_memory = allocate(_total_size);
......@@ -23,9 +25,10 @@ WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) {
void *WorkspaceAllocator::alloc(size_t new_size) {
if (_total_size < new_size) {
if (_total_size != 0) {
RUN_INFINI(infinirtDeviceSynchronize());
RUN_INFINI(infinirtFree(_memory));
}
_total_size = aligned_size(new_size * 3 / 2, _align);
_total_size = aligned_size(new_size, _align);
_memory = allocate(_total_size);
}
return _memory;
......@@ -36,6 +39,7 @@ void WorkspaceAllocator::release(void *ptr) {
// Free the workspace buffer, if one was ever allocated.
WorkspaceAllocator::~WorkspaceAllocator() {
    if (_memory != nullptr) {
        // Make sure no in-flight device work still uses the buffer
        // before handing it back to the runtime.
        RUN_INFINI(infinirtDeviceSynchronize());
        RUN_INFINI(infinirtFree(_memory));
    }
}
\ No newline at end of file
......@@ -61,6 +61,52 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
comm,
std::make_unique<WorkspaceAllocator>(0),
};
RUN_INFINI(infinirtDeviceSynchronize());
}
// Tear down everything owned by a per-device resource bundle:
// weight tensors, the workspace allocator, and the runtime handles.
void releaseDeviceResource(DeviceResource &res) {
    // Wait for all in-flight work on this device before freeing anything.
    // NOTE(review): unlike most call sites this result is not wrapped in
    // RUN_INFINI — presumably best-effort during teardown; confirm intended.
    infinirtDeviceSynchronize();

    // Release individual tensors.
    res.w_in_embd.reset();
    res.w_out_norm.reset();
    res.w_out_embd.reset();
    res.sin_table.reset();
    res.cos_table.reset();

    // For the per-layer weight vectors, clear() already destroys every
    // element (dropping each smart-pointer reference), so the previous
    // per-element reset() loops were redundant and have been removed.
    res.w_attn_norm.clear();
    res.w_attn_qkv.clear();
    res.b_attn_qkv.clear();
    res.w_attn_out.clear();
    res.w_ffn_norm.clear();
    res.w_ffn_gate_up.clear();
    res.w_ffn_down.clear();

    // Free the scratch workspace before destroying the runtime handles
    // its memory was allocated under.
    res.workspace_allocator.reset();

    // Destroy runtime handles and null them so a stale pointer is never
    // reused after release.
    infiniopDestroyHandle(res.handle);
    res.handle = nullptr;
    infinirtStreamDestroy(res.stream);
    res.stream = nullptr;
    infinicclCommDestroy(res.comm);
    res.comm = nullptr;
}
void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
......@@ -291,6 +337,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infinicclAllReduce(
logits_in->data(), logits_in->data(), ntok * d, dt_logits,
INFINICCL_SUM, rsrc.comm, stream));
RUN_INFINI(infinirtStreamSynchronize(stream));
}
// 2. FFN
// rms_norm
......@@ -315,6 +362,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infinicclAllReduce(
logits_in->data(), logits_in->data(), ntok * d, dt_logits,
INFINICCL_SUM, rsrc.comm, stream));
RUN_INFINI(infinirtStreamSynchronize(stream));
}
}
// Sample and Output
......@@ -408,10 +456,20 @@ inferBatch(struct JiugeModel *model,
void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req,
infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
// Create Device Resource
createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);
{
std::unique_lock<std::mutex> lock(state.mtx);
state.loaded = true;
lock.unlock();
state.cv_load.notify_one();
}
// Infer Loop
while (true) {
std::unique_lock<std::mutex> lock(state.mtx);
state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; });
// quit if exit_flag is set
if (state.exit_flag) {
break;
}
......@@ -423,9 +481,8 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
state.cv_done.notify_one();
}
infiniopDestroyHandle(rsrc->handle);
infinirtStreamDestroy(rsrc->stream);
infinicclCommDestroy(rsrc->comm);
// Clean-Up
releaseDeviceResource(*rsrc);
}
JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) {
......@@ -444,6 +501,11 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi
for (int i = 0; i < ndev; i++) {
threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]);
}
for (int i = 0; i < ndev; i++) {
std::unique_lock<std::mutex> lock(states[i].mtx);
states[i].cv_load.wait(lock, [&] { return states[i].loaded; });
lock.unlock();
}
}
__C struct JiugeModel *
......
......@@ -32,7 +32,8 @@ struct DeviceResource {
// Per-device worker-thread synchronization state.
// The diff residue that declared cv_start/cv_done twice (old and new
// lines of the hunk both present) is collapsed to the post-change form.
struct InferState {
    std::mutex mtx;
    // cv_load : signaled once device resources are created (loaded).
    // cv_start: wakes the worker when a batch is ready (proceed) or on exit.
    // cv_done : signaled by the worker when a batch finishes.
    std::condition_variable cv_load, cv_start, cv_done;
    bool loaded = false;    // device resources fully created
    bool proceed = false;   // a batch is ready for the worker to run
    bool exit_flag = false; // request worker-thread shutdown
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment