OpenDAS / ktransformers

Commit 28ff784a, authored Jun 06, 2025 by fengchao
Parent: 0bc665e9

[DAS] Adapt code for dcu
Showing 5 changed files with 32 additions and 40 deletions:

    csrc/custom_marlin/gptq_marlin/gptq_marlin.cu           +2   -2
    csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh   +19  -19
    csrc/ktransformers_ext/CMakeLists.txt                   +9   -8
    ktransformers/operators/attention.py                    +1   -5
    ktransformers/operators/models.py                       +1   -6
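Taken together, the commit ports ktransformers to DCU accelerators, most likely Hygon DCU given the Z100/Z100L/K100 device names checked below (the commit message does not expand the acronym). It normalizes a CUDA kernel-launch spelling, unguards bfloat16 conversion helpers hidden behind an NVIDIA-only `__CUDA_ARCH__` check, hard-disables Intel AMX in the CPU build, and routes the attention and model forward passes to the non-flash-attention fallback on DCU hardware.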
csrc/custom_marlin/gptq_marlin/gptq_marlin.cu

```diff
@@ -1765,7 +1765,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   if (has_act_order) {
     // Permute A columns
     int block_rows = div_ceil(prob_m, blocks);
-    permute_cols_kernel<< <blocks, default_threads, 0, stream>> >(
+    permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
         A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
     A_ptr = a_tmp_ptr;
   }
@@ -2032,4 +2032,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   return c;
 }
-#endif
\ No newline at end of file
+#endif
```
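The functional change here is the kernel-launch spelling: the spaced chevrons `<< <...>> >` are tokenized as shift and comparison operators rather than as a launch configuration, which the DCU toolchain evidently does not accept, so the commit normalizes them to the standard triple-chevron form (the second hunk only adds the missing final newline). Below is a minimal, self-contained sketch of the same launch pattern; `permute_demo` and all sizes are hypothetical stand-ins, not the real `permute_cols_kernel`:

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical permutation kernel: copies src into dst through a column
// index map, grid-stride over rows and block-stride over columns. A stand-in
// for permute_cols_kernel, whose real signature lives in gptq_marlin.cu.
__global__ void permute_demo(const float* src, const int* perm, float* dst,
                             int rows, int cols) {
  for (int r = blockIdx.x; r < rows; r += gridDim.x)
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
      dst[r * cols + c] = src[r * cols + perm[c]];
}

int main() {
  const int rows = 4, cols = 256, blocks = 4, threads = 256;
  float *src, *dst;
  int* perm;
  cudaMallocManaged(&src, rows * cols * sizeof(float));
  cudaMallocManaged(&dst, rows * cols * sizeof(float));
  cudaMallocManaged(&perm, cols * sizeof(int));
  for (int c = 0; c < cols; ++c) perm[c] = cols - 1 - c;  // reverse columns
  for (int i = 0; i < rows * cols; ++i) src[i] = float(i);
  // The launch form the commit settles on: <<<grid, block, shmem, stream>>>.
  permute_demo<<<blocks, threads, 0, 0>>>(src, perm, dst, rows, cols);
  cudaDeviceSynchronize();
  printf("dst[0] = %f (expect %f)\n", dst[0], float(cols - 1));
  cudaFree(src); cudaFree(dst); cudaFree(perm);
  return 0;
}
```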
csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh

Only the guard lines change (the raw diff shows the whole block as removed and re-added, hence the +19/-19 count); the helper bodies are untouched.

```diff
@@ -52,26 +52,26 @@ template <> class ScalarType<nv_bfloat16> {
   using FragC = Vec<float, 4>;
   using FragS = Vec<nv_bfloat162, 1>;
 
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static __device__ float inline num2float(const nv_bfloat16 x) {
     return __bfloat162float(x);
   }
 
   static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
     return __bfloat162bfloat162(x);
   }
 
   static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                   const nv_bfloat16 x2) {
     return __halves2bfloat162(x1, x2);
   }
 
   static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
     return __float2bfloat16(x);
   }
-#endif
+  // #endif
 };
 
 }  // namespace gptq_marlin
-#endif
\ No newline at end of file
+#endif
```
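In a DCU build `__CUDA_ARCH__` is never defined, so the original guard compiled these bf16 conversion helpers out and any Marlin path that calls them would fail; commenting the guard out makes the helpers unconditional. Assuming the DCU toolchain exposes these intrinsics (which the commit implies but does not state), a more explicit alternative would widen the guard rather than delete it; this is a sketch, not the commit's approach:

```cuda
#include <cuda_bf16.h>

// Sketch: keep the Ampere check for NVIDIA compiles, but also admit HIP/DCU
// device passes. __HIP_DEVICE_COMPILE__ is defined by hipcc during device
// compilation; treating it as "bf16 intrinsics available" is an assumption
// about the DCU toolchain.
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || \
    defined(__HIP_DEVICE_COMPILE__)
static __device__ float inline num2float(const nv_bfloat16 x) {
  return __bfloat162float(x);
}
#endif
```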
csrc/ktransformers_ext/CMakeLists.txt

```diff
@@ -118,7 +118,8 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
     message(STATUS "x86 detected")
     set(HOST_IS_X86 TRUE)
     set(HAS_AVX512 TRUE)
-    set(__HAS_AMX__ TRUE)
+    set(__HAS_AMX__ False)
+    #set(__HAS_AMX__ TRUE)
     add_compile_definitions(__x86_64__)
     # check AVX512
     execute_process(
@@ -141,12 +142,12 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
     # check AMX
     string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
 
-    if (COMPILER_SUPPORTS_AMX GREATER -1)
-        message(STATUS "Compiler supports AMX")
-        add_compile_definitions(__HAS_AMX__)
-    else()
-        message(STATUS "Compiler does NOT support AMX")
-    endif()
+    # if(COMPILER_SUPPORTS_AMX GREATER -1)
+    #     message(STATUS "Compiler supports AMX")
+    #     add_compile_definitions(__HAS_AMX__)
+    # else()
+    message(STATUS "Compiler does NOT support AMX")
+    # endif()
     if (MSVC)
         # instruction set detection for MSVC only
         if (LLAMA_NATIVE)
@@ -396,4 +397,4 @@ else()
     else()
         message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
     endif()
-endif()
\ No newline at end of file
+endif()
```
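The build changes unconditionally disable AMX: the variable is forced to `False` and the `lscpu`-based detection that could re-enable `__HAS_AMX__` is commented out, presumably because the host CPUs paired with DCU accelerators do not implement Intel AMX. A sketch of a more configurable variant; `KTRANSFORMERS_USE_AMX` is a hypothetical cache option, not one this repository defines:

```cmake
# Sketch: expose the AMX toggle as a user-facing option instead of a
# hard-coded set(). Defaults to OFF to match the commit's behavior.
option(KTRANSFORMERS_USE_AMX "Enable Intel AMX kernels" OFF)
if (KTRANSFORMERS_USE_AMX)
    add_compile_definitions(__HAS_AMX__)
    message(STATUS "AMX kernels enabled")
else()
    message(STATUS "AMX kernels disabled")
endif()
```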
ktransformers/operators/attention.py

```diff
@@ -704,11 +704,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 cache_position,
                 **kwargs,
             )
-        elif (os.name == 'nt' or get_compute_capability() < 8 or hidden_states.device.type == 'cpu'
-                or device_manager.gpu_vendor != GPUVendor.NVIDIA)
-                or ("Z100" in get_device_name())
-                or ("Z100L" in get_device_name())
-                or ("K100" in get_device_name()):
+        elif (os.name == 'nt' or get_compute_capability() < 8 or hidden_states.device.type == 'cpu' or device_manager.gpu_vendor != GPUVendor.NVIDIA or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()) or ("K100" in get_device_name())):
             print("for Windows or GPU before ampere or Z100/Z100L or K100, use forward_windows")
             return self.forward_windows(
                 hidden_states,
```
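The removed version closed its parenthesis after the `GPUVendor.NVIDIA` test, leaving the trailing `or ("Z100" ...)` clauses dangling across unbracketed continuation lines, a Python syntax error; the fix wraps the whole condition in a single pair of parentheses. The predicate routes to `forward_windows` whenever flash attention is unusable: Windows, pre-Ampere compute capability, CPU tensors, a non-NVIDIA vendor, or a Hygon Z100/Z100L/K100 DCU. A standalone rendering of that predicate; `needs_windows_fallback` and its explicit parameters are hypothetical, and in the repository the values come from `get_compute_capability()`, `hidden_states.device.type`, `device_manager.gpu_vendor`, and `get_device_name()`:

```python
import os

def needs_windows_fallback(compute_capability: int, device_type: str,
                           gpu_vendor: str, device_name: str) -> bool:
    """Hypothetical standalone rendering of the commit's elif condition:
    True when the flash-attention path cannot be used."""
    dcu_models = ("Z100", "Z100L", "K100")  # Hygon DCU parts named in the diff
    return (
        os.name == 'nt'                 # Windows: no flash-attn path
        or compute_capability < 8       # pre-Ampere NVIDIA GPU
        or device_type == 'cpu'         # tensor lives on the CPU
        or gpu_vendor != 'NVIDIA'       # non-NVIDIA vendor
        or any(m in device_name for m in dcu_models)
    )

print(needs_windows_fallback(9, 'cuda', 'NVIDIA', 'Hygon Z100L'))  # True
print(needs_windows_fallback(9, 'cuda', 'NVIDIA', 'NVIDIA A100'))  # False on a non-Windows host
```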
ktransformers/operators/models.py

```diff
@@ -660,12 +660,7 @@ class KDeepseekV2Model(BaseInjectedModule):
         else:
             #if os.name == 'nt' or get_compute_capability()<8:
             #    print("for Windows or GPU before ampere, use forward_windows")
-            if os.name == 'nt' or get_compute_capability() < 8
-                or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
-                or device_manager.gpu_vendor != GPUVendor.NVIDIA
-                or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
-                or device_manager.gpu_vendor != GPUVendor.NVIDIA
-                or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()) or ("K100" in get_device_name()):
+            if os.name == 'nt' or get_compute_capability() < 8 or (self.transfer_map is not None and 'cpu' in self.transfer_map.values()) or device_manager.gpu_vendor != GPUVendor.NVIDIA or ("Z100" in get_device_name()) or ("Z100L" in get_device_name()) or ("K100" in get_device_name()):
                 print("for Windows or GPU before ampere or Z100/Z100L or K100, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
```
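Same fix as in attention.py, plus a cleanup: the broken multi-line condition had the `transfer_map` and `gpu_vendor` clauses duplicated, and the rewrite keeps a single copy of each on one syntactically valid line. The clause unique to models.py forces the fallback when any pipeline stage is mapped to the CPU; a small illustrative check, where `transfer_map` is the model's real attribute but the dict literal is invented:

```python
# transfer_map maps layer indices to devices; a 'cpu' entry means part of
# the pipeline runs on the host, so the flash-attention path is skipped.
transfer_map = {0: 'cuda:0', 30: 'cpu'}  # illustrative values
offloads_to_cpu = transfer_map is not None and 'cpu' in transfer_map.values()
print(offloads_to_cpu)  # True
```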