Commit f5f79f5c authored by chenxl
Browse files

[ADD] support multi-gpu qlen>1 q5_k

parent f2938031
...@@ -21,6 +21,7 @@ class CUDAGraphRunner: ...@@ -21,6 +21,7 @@ class CUDAGraphRunner:
position_ids, position_ids,
cache_position, cache_position,
past_key_values, past_key_values,
main_device,
**kwargs, **kwargs,
) -> None: ) -> None:
assert self.graph is None assert self.graph is None
...@@ -29,15 +30,24 @@ class CUDAGraphRunner: ...@@ -29,15 +30,24 @@ class CUDAGraphRunner:
self.graph = torch.cuda.CUDAGraph() self.graph = torch.cuda.CUDAGraph()
#self.graph.enable_debug_mode() #self.graph.enable_debug_mode()
self.model = model self.model = model
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda") inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device)
with torch.cuda.graph(self.graph): # torch.cuda.set_device can't set "cuda", must have a index
if main_device == "cuda":
main_device = "cuda:0"
torch.cuda.set_device(main_device)
self.main_device = main_device
capture_stream = torch.cuda.Stream()
with torch.cuda.graph(self.graph, stream = capture_stream):
logits=model(inputs_embeds=inputs_embeds, logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,
cache_position=cache_position, cache_position=cache_position,
past_key_values=past_key_values, past_key_values=past_key_values,
**kwargs)[0] **kwargs)[0]
capture_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.set_device(main_device)
torch.cuda.set_stream(capture_stream)
past_key_values.change_seq_length(-1) past_key_values.change_seq_length(-1)
torch.cuda.synchronize() torch.cuda.synchronize(self.main_device)
#self.graph.debug_dump("cuda_graph_hooked.dot") #self.graph.debug_dump("cuda_graph_hooked.dot")
# Save the input and output buffers. # Save the input and output buffers.
...@@ -65,7 +75,7 @@ class CUDAGraphRunner: ...@@ -65,7 +75,7 @@ class CUDAGraphRunner:
#print("begin replay") #print("begin replay")
#time.sleep(1) #time.sleep(1)
self.graph.replay() self.graph.replay()
torch.cuda.synchronize() torch.cuda.synchronize(self.main_device)
# Return the output tensor. # Return the output tensor.
return self.output_buffers["logits"] return self.output_buffers["logits"]
......
This diff is collapsed.
This diff is collapsed.
...@@ -3,7 +3,8 @@ requires = [ ...@@ -3,7 +3,8 @@ requires = [
"setuptools", "setuptools",
"torch >= 2.3.0", "torch >= 2.3.0",
"ninja", "ninja",
"packaging" "packaging",
"cpufeature"
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
......
This diff is collapsed.
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <cstring> #include <cstring>
#include <type_traits> #include <type_traits>
#if defined __x86_64__ || defined __aarch64__ #if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)
#include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml-quants.h"
...@@ -225,7 +225,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const voi ...@@ -225,7 +225,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const voi
return true; return true;
} }
#if defined __x86_64__ #if defined __x86_64__ || defined(_M_X64)
#if defined HAVE_FANCY_SIMD #if defined HAVE_FANCY_SIMD
#undef HAVE_FANCY_SIMD #undef HAVE_FANCY_SIMD
...@@ -1412,6 +1412,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { ...@@ -1412,6 +1412,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) { bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) {
if (ne00 % ggml_blck_size(GGML_TYPE_Q8_K) == 0)
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
switch (typeA) { switch (typeA) {
......
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
// Copyrigth 2024 Iwan Kawrakow. // Copyrigth 2024 Iwan Kawrakow.
// Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.
#ifdef __x86_64__ #if defined(__x86_64__) || defined(_M_X64)
#include "iqk_mul_mat.inc" #include "iqk_mul_mat.inc"
#endif // __x86_64__ #endif // __x86_64__
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment