Commit e8e83308 authored by qiyuxinlin's avatar qiyuxinlin

fix flashinfer float_workspace_buffer being too small

parent 02948bc1
@@ -195,13 +195,13 @@ class Engine:
         self.block_num = inference_context.k_cache[0].size(1)
+        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
         #@TODO add config
         if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
-            self.model.init_wrapper(self.args.use_cuda_graph, self.device, Config().chunk_size, args.max_batch_size, self.block_num) # TODO: 1024 is a magic number(max_batch_tokens)
+            self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
         else:
             self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
-        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
         self.sampler = Sampler()
         self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
...
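
For context: the change moves the ModelRunner construction ahead of the init_wrapper call so that the flashinfer workspace can be sized from the largest captured CUDA graph, max(self.model_runner.cuda_graphs), instead of Config().chunk_size, which left float_workspace_buffer too small whenever a captured graph batch exceeded the prefill chunk size. Below is a minimal, hypothetical sketch of that sizing concern; the function body, the 128 MB floor, the scaling factor, and the choice of flashinfer wrapper class are illustrative assumptions, not the repository's actual code.

```python
# Hypothetical sketch of why the token cap passed to init_wrapper matters.
# The name and parameter order mirror the diff; the buffer math and the
# wrapper choice are illustrative assumptions.
import torch
import flashinfer  # third-party dependency; assumed installed

def init_wrapper(use_cuda_graph: bool, device: str, max_batch_tokens: int,
                 max_batch_size: int, block_num: int):
    # flashinfer wrappers take a caller-allocated float workspace buffer.
    # It must cover the largest batch the wrapper will ever see; under CUDA
    # graphs that is the largest captured graph size, so sizing it from the
    # prefill chunk size (the old Config().chunk_size argument) can
    # undersize it.
    workspace_bytes = max(128 * 1024 * 1024, max_batch_tokens * 1024)  # illustrative
    float_workspace_buffer = torch.empty(
        workspace_bytes, dtype=torch.uint8, device=device)
    # BatchPrefillWithPagedKVCacheWrapper is one flashinfer wrapper that
    # accepts such a buffer; the model may use a different one internally.
    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        float_workspace_buffer, kv_layout="NHD")
```

Under this sketch, the Engine would pass max(self.model_runner.cuda_graphs) as max_batch_tokens, as on the new side of the diff, which is why the ModelRunner must now be created first.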