Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
10572e55
Commit
10572e55
authored
Nov 21, 2025
by
zhuyue
Browse files
Issue/654 - Update CUB API usage for CUDA 12.9+ compatibility
parent
d18b77a0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
15 deletions
+23
-15
src/infiniop/ops/layer_norm/cuda/kernel.cuh
src/infiniop/ops/layer_norm/cuda/kernel.cuh
+1
-1
src/infiniop/ops/lp_norm/cuda/kernel.cuh
src/infiniop/ops/lp_norm/cuda/kernel.cuh
+10
-2
src/infiniop/ops/tanh/operator.cc
src/infiniop/ops/tanh/operator.cc
+12
-12
No files found.
src/infiniop/ops/layer_norm/cuda/kernel.cuh
View file @
10572e55
...
...
@@ -81,7 +81,7 @@ __device__ void blockLayernormKernel(T *output, T const *input, T const *weight,
}
__shared__
float
sigma2
;
float
sigma2_block
=
BlockReduce
(
temp_storage
).
Reduce
(
sigma2_partial
,
cub
::
Sum
()
);
float
sigma2_block
=
BlockReduce
(
temp_storage
).
Sum
(
sigma2_partial
);
if
(
threadIdx
.
x
==
0
)
{
float
sigma_tmp
=
sqrt
(
sigma2_block
*
__fdividef
(
1.0
F
,
dimsize
)
+
eps
);
sigma2
=
__fdividef
(
1.0
F
,
sigma_tmp
);
...
...
src/infiniop/ops/lp_norm/cuda/kernel.cuh
View file @
10572e55
...
...
@@ -17,7 +17,11 @@ __device__ void blockLPNormKernel(
local_max
=
max
(
local_max
,
fabsf
((
float
)
input
[
tid
+
ind
*
stride
]));
}
__shared__
float
global_max
;
#if CUDART_VERSION >= 12090
float
max_block
=
BlockReduce
(
temp_storage
).
Reduce
(
local_max
,
::
cuda
::
maximum
());
#else
float
max_block
=
BlockReduce
(
temp_storage
).
Reduce
(
local_max
,
cub
::
Max
());
#endif
if
(
threadIdx
.
x
==
0
)
{
// must set threadIdx.x = 0 write the output to memory
global_max
=
max_block
;
}
...
...
@@ -30,7 +34,7 @@ __device__ void blockLPNormKernel(
}
__shared__
float
p_total
;
float
p_block
=
BlockReduce
(
temp_storage
).
Reduce
(
p_partial
,
cub
::
Sum
()
);
float
p_block
=
BlockReduce
(
temp_storage
).
Sum
(
p_partial
);
if
(
threadIdx
.
x
==
0
)
{
// must set threadIdx.x = 0 write the output to memory
p_total
=
powf
(
p_block
,
1.0
f
/
p
);
}
...
...
@@ -69,7 +73,11 @@ __device__ void blockLPNormStridesKernel(
local_max
=
max
(
local_max
,
fabsf
((
float
)
input
[
ind_i
+
ind
]));
}
__shared__
float
global_max
;
#if CUDART_VERSION >= 12090
float
max_block
=
BlockReduce
(
temp_storage
).
Reduce
(
local_max
,
::
cuda
::
maximum
());
#else
float
max_block
=
BlockReduce
(
temp_storage
).
Reduce
(
local_max
,
cub
::
Max
());
#endif
if
(
threadIdx
.
x
==
0
)
{
// must set threadIdx.x = 0 write the output to memory
global_max
=
max_block
;
}
...
...
@@ -82,7 +90,7 @@ __device__ void blockLPNormStridesKernel(
}
__shared__
float
p_total
;
float
p_block
=
BlockReduce
(
temp_storage
).
Reduce
(
p_partial
,
cub
::
Sum
()
);
float
p_block
=
BlockReduce
(
temp_storage
).
Sum
(
p_partial
);
if
(
threadIdx
.
x
==
0
)
{
// must set threadIdx.x = 0 write the output to memory
p_total
=
powf
(
p_block
,
1.0
f
/
p
);
}
...
...
src/infiniop/ops/tanh/operator.cc
View file @
10572e55
...
...
@@ -40,9 +40,9 @@ __C infiniStatus_t infiniopCreateTanhDescriptor(
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax);
// #endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax);
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -71,9 +71,9 @@ __C infiniStatus_t infiniopGetTanhWorkspaceSize(infiniopTanhDescriptor_t desc, s
#ifdef ENABLE_QY_API
GET
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax);
// #endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax);
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
...
...
@@ -109,9 +109,9 @@ __C infiniStatus_t infiniopTanh(
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax);
// #endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax);
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -142,9 +142,9 @@ infiniopDestroyTanhDescriptor(infiniopTanhDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
// #ifdef ENABLE_METAX_API
// DELETE(INFINI_DEVICE_METAX, metax);
// #endif
// #ifdef ENABLE_METAX_API
// DELETE(INFINI_DEVICE_METAX, metax);
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment