Commit 0fbecc43 authored by maxiao1's avatar maxiao1
Browse files

Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'

change sgl_kernel WARP_SIZE to 64

See merge request OpenDAS/sglang!3
parents 8fc55263 75cd34d1
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define INTRIN_M 16 #define INTRIN_M 16
#define INTRIN_N 16 #define INTRIN_N 16
#define INTRIN_K 32 #define INTRIN_K 32
#define WARP_SIZE 32 #define WARP_SIZE 64
#define SMEM_PAD_A 0 #define SMEM_PAD_A 0
#define SMEM_PAD_B 0 #define SMEM_PAD_B 0
#define PACK_SIZE 16 #define PACK_SIZE 16
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define INTRIN_M 16 #define INTRIN_M 16
#define INTRIN_N 16 #define INTRIN_N 16
#define INTRIN_K 32 #define INTRIN_K 32
#define WARP_SIZE 32 #define WARP_SIZE 64
#define SMEM_PAD_A 0 #define SMEM_PAD_A 0
#define SMEM_PAD_B 0 #define SMEM_PAD_B 0
#define PACK_SIZE 16 #define PACK_SIZE 16
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <cstdint> #include <cstdint>
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 64
#include "pytorch_extension_utils.h" #include "pytorch_extension_utils.h"
#else #else
#include "pytorch_extension_utils_rocm.h" #include "pytorch_extension_utils_rocm.h"
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h // copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256 #define QK_K 256
#define K_QUANTS_PER_ITERATION 2 #define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE_GGUF 32 #define WARP_SIZE_GGUF 64
#define K_SCALE_SIZE 12 #define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256 #define CUDA_QUANTIZE_BLOCK_SIZE 256
......
...@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() { ...@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 64
#else #else
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
#define WARP_SIZE 64 #define WARP_SIZE 64
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment