Commit 28ff784a authored by fengchao

[DAS] Adapt code for dcu

parent 0bc665e9
@@ -1765,7 +1765,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   if (has_act_order) {
     // Permute A columns
     int block_rows = div_ceil(prob_m, blocks);
-    permute_cols_kernel << <blocks, default_threads, 0, stream >> > (
+    permute_cols_kernel <<<blocks, default_threads, 0, stream >>> (
         A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
     A_ptr = a_tmp_ptr;
   }
...
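The only change in this hunk is launch syntax: the spaced chevrons are collapsed into the standard triple-chevron form, which the clang-based HIP/DCU toolchain expects. A minimal sketch of that form, using an illustrative kernel rather than the project's permute_cols_kernel:

#include <cuda_runtime.h>

__global__ void scale_kernel(float* data, float s, int n) {
  // One thread per element; no grid-stride loop needed for this sketch.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= s;
}

void launch_scale(float* d_data, float s, int n, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  // Standard <<<grid, block, shared_mem, stream>>> launch; the spaced
  // "<< < ... >> >" spelling from the removed line is not portable.
  scale_kernel<<<blocks, threads, 0, stream>>>(d_data, s, n);
}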
@@ -52,24 +52,24 @@ template <> class ScalarType<nv_bfloat16> {
   using FragC = Vec<float, 4>;
   using FragS = Vec<nv_bfloat162, 1>;

-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+//#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static __device__ float inline num2float(const nv_bfloat16 x) {
     return __bfloat162float(x);
   }

   static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
     return __bfloat162bfloat162(x);
   }

   static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                   const nv_bfloat16 x2) {
     return __halves2bfloat162(x1, x2);
   }

   static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
     return __float2bfloat16(x);
   }
-#endif
+//#endif
 };
 }  // namespace gptq_marlin
...
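Commenting out the `__CUDA_ARCH__ >= 800` guard makes the bf16 conversion helpers unconditional. The apparent reasoning: `__CUDA_ARCH__` is defined only by NVIDIA's device compiler, so on a DCU/HIP build the guarded block would disappear and callers of num2float/float2num would fail to compile. A small sketch of that behaviour (the helper name is hypothetical; it assumes the HIP device pass defines __HIP_DEVICE_COMPILE__, as ROCm's hip-clang does):

// Sketch: why an sm_80 arch guard hides code from a HIP/DCU device pass.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// nvcc device passes for sm_80+ define __CUDA_ARCH__, so this branch is kept.
__device__ inline float twice(float x) { return x * 2.0f; }
#elif defined(__HIP_DEVICE_COMPILE__)
// hip-clang device passes define __HIP_DEVICE_COMPILE__ but not __CUDA_ARCH__;
// without this branch (or without relaxing the guard, as the commit does),
// twice() would simply not exist on a DCU build.
__device__ inline float twice(float x) { return x + x; }
#endif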
@@ -118,7 +118,8 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
     message(STATUS "x86 detected")
     set(HOST_IS_X86 TRUE)
     set(HAS_AVX512 TRUE)
-    set(__HAS_AMX__ TRUE)
+    set(__HAS_AMX__ False)
+    #set(__HAS_AMX__ TRUE)
     add_compile_definitions(__x86_64__)
     # check AVX512
     execute_process(
@@ -141,12 +142,12 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
     # check AMX
     string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
-    if(COMPILER_SUPPORTS_AMX GREATER -1)
-        message(STATUS "Compiler supports AMX")
-        add_compile_definitions(__HAS_AMX__)
-    else()
+    #if(COMPILER_SUPPORTS_AMX GREATER -1)
+    #    message(STATUS "Compiler supports AMX")
+    #    add_compile_definitions(__HAS_AMX__)
+    #else()
         message(STATUS "Compiler does NOT support AMX")
-    endif()
+    #endif()
     if (MSVC)
         # instruction set detection for MSVC only
         if (LLAMA_NATIVE)
...
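Both CMake hunks have the same effect: AMX support is forced off for the DCU build, even on hosts whose lscpu reports the amx flag, so the __HAS_AMX__ compile definition is never added. A sketch of how such a definition is typically consumed downstream (function names here are illustrative, not the project's actual kernels):

// Sketch: consuming a __HAS_AMX__ compile definition. With detection disabled
// above, the macro is never defined and only the portable path is built.
#include <cstddef>

void accumulate(const float* a, const float* b, float* c, std::size_t n) {
#ifdef __HAS_AMX__
  // Hypothetical AMX-tiled implementation (immintrin.h tile intrinsics,
  // built with -mamx-tile); elided in this sketch.
  amx_accumulate(a, b, c, n);
#else
  // ISA-independent fallback: the path a DCU host build takes.
  for (std::size_t i = 0; i < n; ++i) c[i] += a[i] * b[i];
#endif
}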
@@ -704,11 +704,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 cache_position,
                 **kwargs,
             )
-        elif (os.name == 'nt' or get_compute_capability()<8
-              or hidden_states.device.type == 'cpu'
-              or device_manager.gpu_vendor != GPUVendor.NVIDIA)
-              or ("Z100" in get_device_name())
-              or ("Z100L" in get_device_name()) or ("K100" in get_device_name()):
+        elif (os.name == 'nt' or get_compute_capability()<8 or hidden_states.device.type == 'cpu' or device_manager.gpu_vendor != GPUVendor.NVIDIA or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()) or ("K100" in get_device_name())):
             print("for Windows or GPU before ampere or Z100/Z100L or K100, use forward_windows")
             return self.forward_windows(
                 hidden_states,
...
@@ -660,12 +660,7 @@ class KDeepseekV2Model(BaseInjectedModule):
         else:
             #if os.name == 'nt' or get_compute_capability()<8:
             #    print("for Windows or GPU before ampere, use forward_windows")
-            if os.name == 'nt' or get_compute_capability()<8
-               or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
-               or device_manager.gpu_vendor != GPUVendor.NVIDIA
-               or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
-               or device_manager.gpu_vendor != GPUVendor.NVIDIA
-               or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()) or ("K100" in get_device_name()):
+            if os.name == 'nt' or get_compute_capability()<8 or (self.transfer_map is not None and 'cpu' in self.transfer_map.values()) or device_manager.gpu_vendor != GPUVendor.NVIDIA or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()) or ("K100" in get_device_name()):
                 print("for Windows or GPU before ampere or Z100/Z100L or K100, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
...