"vscode:/vscode.git/clone" did not exist on "89615203bc456b47a40dd9768940f100d5cf8846"
Commit 13f98ed3 authored by PanZezhong's avatar PanZezhong
Browse files

fix workspace

parent 366d3aef
...@@ -13,6 +13,7 @@ from libinfinicore_infer import ( ...@@ -13,6 +13,7 @@ from libinfinicore_infer import (
DataType, DataType,
DeviceType, DeviceType,
create_jiuge_model, create_jiuge_model,
destroy_jiuge_model,
create_kv_cache, create_kv_cache,
drop_kv_cache, drop_kv_cache,
infer_batch, infer_batch,
...@@ -282,6 +283,9 @@ class JiugeForCauslLM: ...@@ -282,6 +283,9 @@ class JiugeForCauslLM:
tensors_[name_] = data_.get_tensor(name_) tensors_[name_] = data_.get_tensor(name_)
return tensors_ return tensors_
print("Loading model weights to host...")
load_start_time = time.time()
with open(os.path.join(model_dir_path, "config.json"), "r") as f: with open(os.path.join(model_dir_path, "config.json"), "r") as f:
config = json.load(f) config = json.load(f)
...@@ -293,7 +297,12 @@ class JiugeForCauslLM: ...@@ -293,7 +297,12 @@ class JiugeForCauslLM:
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev
) )
elif "fm9g" == config["model_type"]: elif "fm9g" == config["model_type"]:
if any(file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()):
state_dict = load_all_safetensors_from_dir(model_dir_path) state_dict = load_all_safetensors_from_dir(model_dir_path)
else:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu"
)
if LlamaWeightsNaming.match(state_dict): if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
...@@ -317,6 +326,12 @@ class JiugeForCauslLM: ...@@ -317,6 +326,12 @@ class JiugeForCauslLM:
else: else:
raise ValueError("Unsupported model architecture") raise ValueError("Unsupported model architecture")
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
print(f"Creating model on {ndev} devices...")
load_start_time = time.time()
dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
self.model_instance = create_jiuge_model( self.model_instance = create_jiuge_model(
byref(self.meta), byref(self.meta),
...@@ -325,18 +340,21 @@ class JiugeForCauslLM: ...@@ -325,18 +340,21 @@ class JiugeForCauslLM:
ndev, ndev,
dev_ids, dev_ids,
) )
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
def infer(self, input_list, topp=1.0, topk=1, temperature=1.0): def infer(self, input_list, topp=1.0, topk=1, temperature=1.0):
pass pass
def generate(self, input_content, max_steps, topp=1.0, topk=1, temperature=1.0): def generate(self, input_content, max_steps, topp=1.0, topk=1, temperature=1.0):
kv_cache = create_kv_cache(self.model_instance)
input_content = self.tokenizer.apply_chat_template( input_content = self.tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": input_content}], conversation=[{"role": "user", "content": input_content}],
add_generation_prompt=True, add_generation_prompt=True,
tokenize=False, tokenize=False,
) )
print(input_content, end="", flush=True) print(input_content, end="", flush=True)
kv_cache = create_kv_cache(self.model_instance)
tokens = self.tokenizer.encode(input_content) tokens = self.tokenizer.encode(input_content)
ntok = len(tokens) ntok = len(tokens)
nreq = 1 nreq = 1
...@@ -367,6 +385,7 @@ class JiugeForCauslLM: ...@@ -367,6 +385,7 @@ class JiugeForCauslLM:
) )
steps += 1 steps += 1
output_tokens = list(ans) output_tokens = list(ans)
end_time = time.time()
output_str = ( output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0]) self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ") .replace("▁", " ")
...@@ -380,7 +399,7 @@ class JiugeForCauslLM: ...@@ -380,7 +399,7 @@ class JiugeForCauslLM:
ntok = 1 ntok = 1
tokens = (c_uint * ntok)(*output_tokens) tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok]) req_lens = (c_uint * nreq)(*[ntok])
end_time = time.time()
if step_i > 0: if step_i > 0:
total_time += end_time - start_time total_time += end_time - start_time
...@@ -391,6 +410,10 @@ class JiugeForCauslLM: ...@@ -391,6 +410,10 @@ class JiugeForCauslLM:
drop_kv_cache(self.model_instance, kv_cache) drop_kv_cache(self.model_instance, kv_cache)
return output_content, avg_time return output_content, avg_time
def destroy_model_instance(self):
    """Release the native model instance held by this wrapper.

    Frees the C-side resources allocated by create_jiuge_model in __init__
    via the destroyJiugeModel library entry point. After this call
    self.model_instance must not be used again (e.g. no further
    create_kv_cache / infer_batch calls).
    """
    destroy_jiuge_model(self.model_instance)
    print("Model destroyed")
def test(): def test():
if len(sys.argv) < 3: if len(sys.argv) < 3:
...@@ -421,6 +444,7 @@ def test(): ...@@ -421,6 +444,7 @@ def test():
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
model = JiugeForCauslLM(model_path, device_type, ndev) model = JiugeForCauslLM(model_path, device_type, ndev)
model.generate("山东最高的山是?", 500) model.generate("山东最高的山是?", 500)
model.destroy_model_instance()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -92,12 +92,12 @@ def __open_library__(): ...@@ -92,12 +92,12 @@ def __open_library__():
c_int, # int ndev c_int, # int ndev
POINTER(c_int), # int const *dev_ids POINTER(c_int), # int const *dev_ids
] ]
lib.destroyJiugeModel.argtypes = [POINTER(JiugeModel)]
lib.createKVCache.restype = POINTER(KVCache) lib.createKVCache.restype = POINTER(KVCache)
lib.dropKVCache.argtypes = [ctypes.POINTER(JiugeModel), POINTER(KVCache)] lib.dropKVCache.argtypes = [POINTER(JiugeModel), POINTER(KVCache)]
lib.inferBatch.restype = None lib.inferBatch.restype = None
lib.inferBatch.argtypes = [ lib.inferBatch.argtypes = [
ctypes.POINTER(JiugeModel), # struct JiugeModel const * POINTER(JiugeModel), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok c_uint, # unsigned int ntok
POINTER(c_uint), # unsigned int const *req_lens POINTER(c_uint), # unsigned int const *req_lens
...@@ -116,6 +116,7 @@ def __open_library__(): ...@@ -116,6 +116,7 @@ def __open_library__():
LIB = __open_library__() LIB = __open_library__()
create_jiuge_model = LIB.createJiugeModel create_jiuge_model = LIB.createJiugeModel
destroy_jiuge_model = LIB.destroyJiugeModel
create_kv_cache = LIB.createKVCache create_kv_cache = LIB.createKVCache
drop_kv_cache = LIB.dropKVCache drop_kv_cache = LIB.dropKVCache
infer_batch = LIB.inferBatch infer_batch = LIB.inferBatch
...@@ -13,8 +13,7 @@ class WorkspaceAllocator : public AllocatorBase { ...@@ -13,8 +13,7 @@ class WorkspaceAllocator : public AllocatorBase {
private: private:
void *_memory; void *_memory;
size_t _total_size; size_t _total_size;
size_t _used_size; size_t _align;
size_t _align = 256;
public: public:
WorkspaceAllocator(size_t intial_size, size_t align = 256); WorkspaceAllocator(size_t intial_size, size_t align = 256);
......
...@@ -14,6 +14,8 @@ inline void *allocate(size_t size_) { ...@@ -14,6 +14,8 @@ inline void *allocate(size_t size_) {
WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) { WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) {
_align = align; _align = align;
_total_size = 0;
_memory = nullptr;
if (initial_size_ > 0) { if (initial_size_ > 0) {
_total_size = aligned_size(initial_size_, _align); _total_size = aligned_size(initial_size_, _align);
_memory = allocate(_total_size); _memory = allocate(_total_size);
...@@ -23,9 +25,10 @@ WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) { ...@@ -23,9 +25,10 @@ WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) {
void *WorkspaceAllocator::alloc(size_t new_size) { void *WorkspaceAllocator::alloc(size_t new_size) {
if (_total_size < new_size) { if (_total_size < new_size) {
if (_total_size != 0) { if (_total_size != 0) {
RUN_INFINI(infinirtDeviceSynchronize());
RUN_INFINI(infinirtFree(_memory)); RUN_INFINI(infinirtFree(_memory));
} }
_total_size = aligned_size(new_size * 3 / 2, _align); _total_size = aligned_size(new_size, _align);
_memory = allocate(_total_size); _memory = allocate(_total_size);
} }
return _memory; return _memory;
...@@ -36,6 +39,7 @@ void WorkspaceAllocator::release(void *ptr) { ...@@ -36,6 +39,7 @@ void WorkspaceAllocator::release(void *ptr) {
WorkspaceAllocator::~WorkspaceAllocator() { WorkspaceAllocator::~WorkspaceAllocator() {
if (_memory != nullptr) { if (_memory != nullptr) {
RUN_INFINI(infinirtDeviceSynchronize());
RUN_INFINI(infinirtFree(_memory)); RUN_INFINI(infinirtFree(_memory));
} }
} }
\ No newline at end of file
...@@ -61,6 +61,52 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta, ...@@ -61,6 +61,52 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
comm, comm,
std::make_unique<WorkspaceAllocator>(0), std::make_unique<WorkspaceAllocator>(0),
}; };
RUN_INFINI(infinirtDeviceSynchronize());
}
void releaseDeviceResource(DeviceResource &res) {
    // Wait for all in-flight device work to finish before freeing any buffer
    // it may still reference. Wrapped in RUN_INFINI for consistency with the
    // other synchronize call sites in this file (the original silently
    // discarded the status here).
    RUN_INFINI(infinirtDeviceSynchronize());
    // Release the standalone tensors.
    res.w_in_embd.reset();
    res.w_out_norm.reset();
    res.w_out_embd.reset();
    res.sin_table.reset();
    res.cos_table.reset();
    // clear() destroys every element of the vector, which drops each
    // shared_ptr reference — the per-element reset() loops in the original
    // were redundant.
    res.w_attn_norm.clear();
    res.w_attn_qkv.clear();
    res.b_attn_qkv.clear();
    res.w_attn_out.clear();
    res.w_ffn_norm.clear();
    res.w_ffn_gate_up.clear();
    res.w_ffn_down.clear();
    // Free the shared workspace buffer before tearing down the runtime.
    res.workspace_allocator.reset();
    // Destroy runtime handles and null them so a second release is harmless.
    infiniopDestroyHandle(res.handle);
    res.handle = nullptr;
    infinirtStreamDestroy(res.stream);
    res.stream = nullptr;
    infinicclCommDestroy(res.comm);
    res.comm = nullptr;
}
void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...@@ -291,6 +337,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -291,6 +337,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infinicclAllReduce( RUN_INFINI(infinicclAllReduce(
logits_in->data(), logits_in->data(), ntok * d, dt_logits, logits_in->data(), logits_in->data(), ntok * d, dt_logits,
INFINICCL_SUM, rsrc.comm, stream)); INFINICCL_SUM, rsrc.comm, stream));
RUN_INFINI(infinirtStreamSynchronize(stream));
} }
// 2. FFN // 2. FFN
// rms_norm // rms_norm
...@@ -315,6 +362,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -315,6 +362,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infinicclAllReduce( RUN_INFINI(infinicclAllReduce(
logits_in->data(), logits_in->data(), ntok * d, dt_logits, logits_in->data(), logits_in->data(), ntok * d, dt_logits,
INFINICCL_SUM, rsrc.comm, stream)); INFINICCL_SUM, rsrc.comm, stream));
RUN_INFINI(infinirtStreamSynchronize(stream));
} }
} }
// Sample and Output // Sample and Output
...@@ -408,10 +456,20 @@ inferBatch(struct JiugeModel *model, ...@@ -408,10 +456,20 @@ inferBatch(struct JiugeModel *model,
void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req, void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req,
infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) { infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
// Create Device Resource
createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm); createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);
{
std::unique_lock<std::mutex> lock(state.mtx);
state.loaded = true;
lock.unlock();
state.cv_load.notify_one();
}
// Infer Loop
while (true) { while (true) {
std::unique_lock<std::mutex> lock(state.mtx); std::unique_lock<std::mutex> lock(state.mtx);
state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; }); state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; });
// quit if exit_flag is set
if (state.exit_flag) { if (state.exit_flag) {
break; break;
} }
...@@ -423,9 +481,8 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso ...@@ -423,9 +481,8 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
state.cv_done.notify_one(); state.cv_done.notify_one();
} }
infiniopDestroyHandle(rsrc->handle); // Clean-Up
infinirtStreamDestroy(rsrc->stream); releaseDeviceResource(*rsrc);
infinicclCommDestroy(rsrc->comm);
} }
JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) { JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) {
...@@ -444,6 +501,11 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi ...@@ -444,6 +501,11 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi
for (int i = 0; i < ndev; i++) { for (int i = 0; i < ndev; i++) {
threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]); threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]);
} }
for (int i = 0; i < ndev; i++) {
std::unique_lock<std::mutex> lock(states[i].mtx);
states[i].cv_load.wait(lock, [&] { return states[i].loaded; });
lock.unlock();
}
} }
__C struct JiugeModel * __C struct JiugeModel *
......
...@@ -32,7 +32,8 @@ struct DeviceResource { ...@@ -32,7 +32,8 @@ struct DeviceResource {
struct InferState { struct InferState {
std::mutex mtx; std::mutex mtx;
std::condition_variable cv_start, cv_done; std::condition_variable cv_load, cv_start, cv_done;
bool loaded = false;
bool proceed = false; bool proceed = false;
bool exit_flag = false; bool exit_flag = false;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment