OpenDAS / torch-scatter / Commits / 9725b043
"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "9026b86d8a4c5eacab3cc9464654da70a772d328"
Commit 9725b043, authored Jan 07, 2020 by rusty1s
Parent: 9a91c42d

    clean up reduction type

Showing 1 changed file with 137 additions and 186 deletions.

cuda/segment_kernel.cu  (+137, -186)
@@ -10,17 +10,71 @@
 #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
 #define FULL_MASK 0xffffffff

-#define ADD 0
-#define MEAN 1
-#define MIN 2
-#define MAX 3
+enum ReductionType { ADD, MEAN, MIN, MAX };
+
+#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
+  [&] {                                                                        \
+    if (reduce == "add") {                                                     \
+      const ReductionType REDUCE = ADD;                                        \
+      return __VA_ARGS__();                                                    \
+    } else if (reduce == "mean") {                                             \
+      const ReductionType REDUCE = MEAN;                                       \
+      return __VA_ARGS__();                                                    \
+    } else if (reduce == "min") {                                              \
+      const ReductionType REDUCE = MIN;                                        \
+      return __VA_ARGS__();                                                    \
+    } else if (reduce == "max") {                                              \
+      const ReductionType REDUCE = MAX;                                        \
+      return __VA_ARGS__();                                                    \
+    }                                                                          \
+  }()
+
+template <typename scalar_t, ReductionType REDUCE> struct Reducer {
+  static inline __host__ __device__ scalar_t init() {
+    if (REDUCE == MIN) {
+      return std::numeric_limits<scalar_t>::max();
+    } else if (REDUCE == MAX) {
+      return std::numeric_limits<scalar_t>::min();
+    } else {
+      return (scalar_t)0;
+    }
+  }
+
+  static inline __host__ __device__ void
+  update(scalar_t *val, scalar_t new_val, int64_t *arg, int64_t new_arg) {
+    if ((REDUCE == MIN && new_val < *val) ||
+        (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+      *arg = new_arg;
+    } else {
+      *val = *val + new_val;
+    }
+  }
+
+  static inline __host__ __device__ void
+  write(scalar_t *address, scalar_t val, int64_t *arg_address, int64_t arg,
+        int count) {
+    if (REDUCE == ADD) {
+      *address = val;
+    } else if (REDUCE == MEAN) {
+      *address = val / (scalar_t)max(count, 1);
+    } else if (REDUCE == MIN || REDUCE == MAX) {
+      if (count > 0) {
+        *address = val;
+        *arg_address = arg;
+      } else {
+        *address = (scalar_t)0;
+      }
+    }
+  }
+};

 // We need our own `IndexToOffset` implementation since we do not want to access
 // the last element of the `indexptr`.
-template <typename T, typename I> struct IndexPtrToOffset {
-  static __host__ __device__ I
-  get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
-    I offset = idx % (info.sizes[info.dims - 1] - 1);
+template <typename scalar_t> struct IndexPtrToOffset {
+  static inline __host__ __device__ int
+  get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
+    int offset = idx % (info.sizes[info.dims - 1] - 1);
     offset *= info.strides[info.dims - 1];
     idx /= info.sizes[info.dims - 1] - 1;
     for (int i = info.dims - 2; i >= 0; --i) {
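The `AT_DISPATCH_REDUCTION_TYPES` macro introduced above follows the same pattern as PyTorch's `AT_DISPATCH_FLOATING_TYPES`: it maps a runtime string onto a `const` enum value that the dispatched lambda can then use as a compile-time template argument. The following standalone sketch (not part of the commit; the macro and function names are illustrative only) shows why that works:

// Standalone sketch (not part of the commit): how a string-to-enum dispatch
// macro turns a runtime `reduce` argument into a compile-time constant that
// can parameterize a template. Macro and function names are illustrative.
#include <iostream>
#include <string>

enum ReductionType { ADD, MEAN, MIN, MAX };

#define DISPATCH_REDUCTION(reduce, ...)                                        \
  [&] {                                                                        \
    if (reduce == "add") {                                                     \
      const ReductionType REDUCE = ADD;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "min") {                                              \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    } else {                                                                   \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
  }()

// `REDUCE` is a constant expression inside the dispatched lambda, so the
// branch below is resolved at compile time for each instantiation.
template <ReductionType REDUCE> float combine(float a, float b) {
  if (REDUCE == ADD)
    return a + b;
  else if (REDUCE == MIN)
    return a < b ? a : b;
  else
    return a > b ? a : b;
}

int main() {
  std::string reduce = "min";
  float out =
      DISPATCH_REDUCTION(reduce, [&] { return combine<REDUCE>(3.f, 2.f); });
  std::cout << out << std::endl; // prints 2
  return 0;
}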
@@ -31,170 +85,85 @@ template <typename T, typename I> struct IndexPtrToOffset {
   }
 };

-template <typename scalar_t, int REDUCE, int TB>
+template <typename scalar_t, ReductionType REDUCE, int TB>
 __global__ void
-segment_add_csr_kernel(const scalar_t *src_data,
-                       const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
-                       scalar_t *out_data, int64_t *arg_out_data, size_t N,
-                       size_t E) {
+segment_csr_kernel(const scalar_t *src_data,
+                   const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
+                   scalar_t *out_data, int64_t *arg_out_data, size_t N,
+                   size_t E) {

-  // Each warp processes exactly `32/TB` rows.
+  // Each warp processes exactly `32/TB` rows and aggregates all row values via
+  // a parallel reduction.
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int row_idx = thread_idx / TB;
   int lane_idx = thread_idx & (TB - 1);

   if (row_idx < N) {
-    int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
+    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
     int row_start = __ldg(indptr_info.data + offset);
     int row_end = __ldg(indptr_info.data + offset +
                         indptr_info.strides[indptr_info.dims - 1]);

-    scalar_t val, tmp;
-    int64_t arg_val, arg_tmp;
-    if (REDUCE == ADD) {
-      val = (scalar_t)0;
-    } else if (REDUCE == MEAN) {
-      val = (scalar_t)0;
-    } else if (REDUCE == MIN) {
-      val = std::numeric_limits<scalar_t>::max();
-    } else if (REDUCE == MAX) {
-      val = std::numeric_limits<scalar_t>::min();
-    }
+    scalar_t val = Reducer<scalar_t, REDUCE>::init();
+    int64_t arg, tmp;

     offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
     for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
-      tmp = src_data[offset + src_idx]; // "Mostly" coalesced read.
-      if (REDUCE == ADD) {
-        val += tmp;
-      } else if (REDUCE == MEAN) {
-        val += tmp;
-      } else if (REDUCE == MIN && tmp < val) {
-        val = tmp;
-        arg_val = src_idx;
-      } else if (REDUCE == MAX && tmp > val) {
-        val = tmp;
-        arg_val = src_idx;
-      }
+      Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
+                                        src_idx);
     }

 #pragma unroll
     for (int i = TB / 2; i > 0; i /= 2) {
       // Parallel reduction inside a single warp.
-      tmp = __shfl_down_sync(FULL_MASK, val, i);
-      if (REDUCE == ADD) {
-        val += tmp;
-      } else if (REDUCE == MEAN) {
-        val += tmp;
-      } else if (REDUCE == MIN) {
-        arg_tmp = __shfl_down_sync(FULL_MASK, arg_val, i);
-        if (tmp < val) {
-          val = tmp;
-          arg_val = arg_tmp;
-        }
-      } else if (REDUCE == MAX) {
-        arg_tmp = __shfl_down_sync(FULL_MASK, arg_val, i);
-        if (tmp > val) {
-          val = tmp;
-          arg_val = arg_tmp;
-        }
-      }
+      if (REDUCE == MIN || REDUCE == MAX) {
+        tmp = __shfl_down_sync(FULL_MASK, arg, i);
+      }
+      Reducer<scalar_t, REDUCE>::update(
+          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, tmp);
     }

     if (lane_idx == 0) {
       // "Mostly" coalesced write.
-      if (REDUCE == ADD) {
-        out_data[row_idx] = val;
-      } else if (REDUCE == MEAN) {
-        out_data[row_idx] = val / (scalar_t)max(row_end - row_start, 1);
-      } else if (REDUCE == MIN) {
-        if (row_end - row_start > 0) {
-          out_data[row_idx] = val;
-          arg_out_data[row_idx] = arg_val;
-        } else {
-          out_data[row_idx] = 0;
-        }
-      } else if (REDUCE == MAX) {
-        if (row_end - row_start > 0) {
-          out_data[row_idx] = val;
-          arg_out_data[row_idx] = arg_val;
-        } else {
-          out_data[row_idx] = 0;
-        }
-      }
+      Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
+                                       arg_out_data + row_idx, arg,
+                                       row_end - row_start);
     }
   }
 }

-template <typename scalar_t, int REDUCE>
-__global__ void segment_add_csr_broadcast_kernel(
+template <typename scalar_t, ReductionType REDUCE>
+__global__ void segment_csr_broadcast_kernel(
     const scalar_t *src_data,
     const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
     scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {

-  // Each thread processes exactly one row. It turned out that is more efficient
-  // than using shared memory due to avoiding synchronization barriers.
+  // Each thread processes exactly one row. It turned out that is more
+  // efficient than using shared memory due to avoiding synchronization
+  // barriers.
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int row_idx = thread_idx / K;
   int lane_idx = thread_idx % K;

   if (thread_idx < N * K) {
-    int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
+    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
     int row_start = __ldg(indptr_info.data + offset);
     int row_end = __ldg(indptr_info.data + offset +
                         indptr_info.strides[indptr_info.dims - 1]);

-    scalar_t val, tmp;
-    int64_t arg_val;
-    if (REDUCE == ADD) {
-      val = (scalar_t)0;
-    } else if (REDUCE == MEAN) {
-      val = (scalar_t)0;
-    } else if (REDUCE == MIN) {
-      val = std::numeric_limits<scalar_t>::max();
-    } else if (REDUCE == MAX) {
-      val = std::numeric_limits<scalar_t>::min();
-    }
+    scalar_t val = Reducer<scalar_t, REDUCE>::init();
+    int64_t arg;

     offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
     for (int src_idx = row_start; src_idx < row_end; src_idx++) {
-      tmp = src_data[offset + K * src_idx + lane_idx]; // Coalesced read.
-      if (REDUCE == ADD) {
-        val += tmp;
-      } else if (REDUCE == MEAN) {
-        val += tmp;
-      } else if (REDUCE == MIN && tmp < val) {
-        val = tmp;
-        arg_val = src_idx;
-      } else if (REDUCE == MAX && tmp > val) {
-        val = tmp;
-        arg_val = src_idx;
-      }
+      Reducer<scalar_t, REDUCE>::update(
+          &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
     }

     // Coalesced write.
-    if (REDUCE == ADD) {
-      out_data[thread_idx] = val;
-    } else if (REDUCE == MEAN) {
-      out_data[thread_idx] = val / (scalar_t)max(row_end - row_start, 1);
-    } else if (REDUCE == MIN) {
-      if (row_end - row_start > 0) {
-        out_data[thread_idx] = val;
-        arg_out_data[thread_idx] = arg_val;
-      } else {
-        out_data[thread_idx] = 0;
-      }
-    } else if (REDUCE == MAX) {
-      if (row_end - row_start > 0) {
-        out_data[thread_idx] = val;
-        arg_out_data[thread_idx] = arg_val;
-      } else {
-        out_data[thread_idx] = 0;
-      }
-    }
+    Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
+                                     arg_out_data + thread_idx, arg,
+                                     row_end - row_start);
   }
 }
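Both the old and the new kernel rely on the same warp-level primitive: `__shfl_down_sync` folds values held by higher lanes onto lower lanes, so after log2 of the sub-warp size steps, lane 0 holds the reduction of the whole sub-warp. Below is a minimal standalone CUDA sketch of that pattern (illustration only, fixed to a full 32-lane warp and a plain sum):

// Standalone CUDA sketch (not part of the diff): the warp shuffle reduction
// pattern used by `segment_csr_kernel`, reduced to a 32-lane sum.
#include <cstdio>

#define FULL_MASK 0xffffffff

__global__ void warp_sum_kernel(const float *in, float *out) {
  int lane = threadIdx.x & 31;
  float val = in[threadIdx.x];
  // Fold the upper half of the active range onto the lower half repeatedly;
  // after five steps, lane 0 holds the sum of all 32 values.
  for (int i = 16; i > 0; i /= 2)
    val += __shfl_down_sync(FULL_MASK, val, i);
  if (lane == 0)
    *out = val; // only lane 0 holds the complete reduction
}

int main() {
  float h_in[32], h_out, *d_in, *d_out;
  for (int i = 0; i < 32; ++i)
    h_in[i] = 1.0f;
  cudaMalloc(&d_in, 32 * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, 32 * sizeof(float), cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("%f\n", h_out); // 32.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}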
@@ -223,14 +192,15 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
   }

   at::optional<at::Tensor> arg_out = at::nullopt;
+  int64_t *arg_out_data = nullptr;
   if (reduce == "min" || reduce == "max") {
     arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }

   auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(reduce_dim);
   // auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
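For orientation, the launch bookkeeping above reduces the tensor shapes to three numbers: N (rows to reduce across the whole batch), K (trailing broadcast dimension), and E (entries per batch element along the reduced dimension). A small worked example with assumed shapes (not taken from the repository):

// Assumed example (not from the repository): src of shape [4, 20, 8],
// indptr of shape [4, 6], reduce_dim = 1, so out has shape [4, 5, 8].
#include <cstdio>

int main() {
  long out_rows = 5;         // out.size(reduce_dim) = indptr.size(-1) - 1
  long batch = 24 / 6;       // indptr.numel() / indptr.size(-1)
  long N = out_rows * batch; // rows reduced across the whole batch = 20
  long K = (4 * 5 * 8) / N;  // out.numel() / N: broadcast dimension = 8
  long E = 20;               // src.size(reduce_dim): entries per batch row
  printf("N=%ld K=%ld E=%ld\n", N, K, E);
  return 0;
}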
@@ -238,56 +208,27 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
     auto src_data = src.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();

     // Select the right kernel based on the reduce operation and whether we
     // need broadcasting capabilties (K > 1):
-    if (K == 1 && reduce == "add") {
-      segment_add_csr_kernel<scalar_t, ADD, 1>
-          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                  out_data, nullptr, N, E);
-    } else if (K == 1 && reduce == "mean") {
-      segment_add_csr_kernel<scalar_t, MEAN, 1>
-          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                  out_data, nullptr, N, E);
-    } else if (K == 1 && reduce == "min") {
-      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
-      segment_add_csr_kernel<scalar_t, MIN, 1>
-          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                  out_data, arg_out_data, N, E);
-    } else if (K == 1 && reduce == "max") {
-      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
-      segment_add_csr_kernel<scalar_t, MAX, 1>
-          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                  out_data, arg_out_data, N, E);
-    } else if (reduce == "add") {
-      segment_add_csr_broadcast_kernel<scalar_t, ADD>
-          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                     out_data, nullptr, N, K, E);
-    } else if (reduce == "mean") {
-      segment_add_csr_broadcast_kernel<scalar_t, MEAN>
-          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                     out_data, nullptr, N, K, E);
-    } else if (reduce == "min") {
-      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
-      segment_add_csr_broadcast_kernel<scalar_t, MIN>
-          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                     out_data, arg_out_data, N,
-                                                     K, E);
-    } else if (reduce == "max") {
-      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
-      segment_add_csr_broadcast_kernel<scalar_t, MAX>
-          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                     out_data, arg_out_data, N,
-                                                     K, E);
-    }
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      if (K == 1) {
+        segment_csr_kernel<scalar_t, REDUCE, 1>
+            <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                    out_data, arg_out_data, N,
+                                                    E);
+      } else {
+        segment_csr_broadcast_kernel<scalar_t, REDUCE>
+            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                       out_data, arg_out_data,
+                                                       N, K, E);
+      }
+    });
   });

   return std::make_tuple(out, arg_out);
 }

-template <typename scalar_t, int REDUCE>
+template <typename scalar_t, ReductionType REDUCE>
 __global__ void
-segment_add_coo_kernel(const scalar_t *src_data,
-                       const at::cuda::detail::TensorInfo<int64_t, int> index_info,
-                       scalar_t *out_data, int64_t *arg_out_data, size_t E) {
+segment_coo_kernel(const scalar_t *src_data,
+                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+                   scalar_t *out_data, int64_t *arg_out_data, size_t E) {

   // Each thread processes exactly one entry. Within a warp, we perform a
   // parallel reduction across equal indices, and write the intermediate
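The grid sizes in the launches above come from the `BLOCKS(TB, N)` macro shown in the first hunk, which is a ceiling division of `TB * N` threads by the block size. A quick standalone check (assuming `THREADS` is 256 for illustration; the actual value is defined near the top of the file, outside the hunks shown here):

// Quick check of the BLOCKS(TB, N) grid-size arithmetic, with an assumed
// THREADS value of 256.
#include <cstdio>

#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS

int main() {
  // 32 threads cooperate per row over N = 100 rows:
  // 3200 threads total -> ceil(3200 / 256) = 13 blocks.
  printf("%d\n", BLOCKS(32, 100)); // 13
  // One thread per output element over N * K = 1000 elements:
  printf("%d\n", BLOCKS(1, 1000)); // ceil(1000 / 256) = 4
  return 0;
}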
@@ -318,15 +259,15 @@ __global__ void segment_add_coo_kernel(
   }
 }

-template <typename scalar_t, int REDUCE, int TB>
-__global__ void segment_add_coo_broadcast_kernel(
+template <typename scalar_t, ReductionType REDUCE, int TB>
+__global__ void segment_coo_broadcast_kernel(
     const scalar_t *src_data,
     const at::cuda::detail::TensorInfo<int64_t, int> index_info,
     scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {

-  // Each thread processes a single column and `TB` rows. Coalesced read and
-  // write is performed in column-major order. The intermediate results are
-  // written via atomics.
+  // Each thread processes a single column and `TB` index entries. Coalesced
+  // read and write is performed in column-major order. The intermediate
+  // results are written via atomics.
   int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
   int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
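The comment above notes that intermediate results are written via atomics: whenever several entries of `index` point at the same output row, the colliding threads must serialize their writes. A minimal standalone sketch of that fallback for a plain scatter-add (hypothetical kernel, not the one in this file):

// Standalone sketch (hypothetical kernel, not the one in this file): the
// atomic fallback a COO-style segment sum relies on when several threads
// target the same output row.
#include <cstdint>
#include <cstdio>

__global__ void scatter_add_kernel(const float *src, const int64_t *index,
                                   float *out, int E) {
  int e = blockIdx.x * blockDim.x + threadIdx.x;
  if (e < E)
    atomicAdd(out + index[e], src[e]); // colliding rows are serialized here
}

int main() {
  const int E = 6;
  float h_src[E] = {1, 2, 3, 4, 5, 6};
  int64_t h_idx[E] = {0, 0, 1, 1, 1, 2}; // three output rows
  float h_out[3] = {0, 0, 0};
  float *d_src, *d_out;
  int64_t *d_idx;
  cudaMalloc(&d_src, E * sizeof(float));
  cudaMalloc(&d_idx, E * sizeof(int64_t));
  cudaMalloc(&d_out, 3 * sizeof(float));
  cudaMemcpy(d_src, h_src, E * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx, h_idx, E * sizeof(int64_t), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, h_out, 3 * sizeof(float), cudaMemcpyHostToDevice);
  scatter_add_kernel<<<1, 32>>>(d_src, d_idx, d_out, E);
  cudaMemcpy(h_out, d_out, 3 * sizeof(float), cudaMemcpyDeviceToHost);
  printf("%g %g %g\n", h_out[0], h_out[1], h_out[2]); // 3 12 6
  return 0;
}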
@@ -392,24 +333,34 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     // Select the right kernel based on average row length (purely heuristic)
     // and whether we need broadcasting capabilties (K > 1):
-    if (K == 1)
-      segment_add_coo_kernel<scalar_t, ADD>
-          <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
-                                                 out_data, nullptr, E);
-    else if (avg_len <= 8)
-      segment_add_coo_broadcast_kernel<scalar_t, ADD, 4>
+    if (K == 1 && reduce == "add") {
+      segment_coo_kernel<scalar_t, ADD><<<BLOCKS(1, E), THREADS, 0, stream>>>(
+          src_data, index_info, out_data, nullptr, E);
+    } else if (K == 1 && reduce == "mean") {
+      segment_coo_kernel<scalar_t, MEAN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
+          src_data, index_info, out_data, nullptr, E);
+    } else if (K == 1 && reduce == "min") {
+      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+      segment_coo_kernel<scalar_t, MIN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
+          src_data, index_info, out_data, arg_out_data, E);
+    } else if (K == 1 && reduce == "max") {
+      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+      segment_coo_kernel<scalar_t, MAX><<<BLOCKS(1, E), THREADS, 0, stream>>>(
+          src_data, index_info, out_data, arg_out_data, E);
+    } else if (avg_len <= 8)
+      segment_coo_broadcast_kernel<scalar_t, ADD, 4>
          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
     else if (avg_len <= 16)
-      segment_add_coo_broadcast_kernel<scalar_t, ADD, 8>
+      segment_coo_broadcast_kernel<scalar_t, ADD, 8>
          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
     else if (avg_len <= 32)
-      segment_add_coo_broadcast_kernel<scalar_t, ADD, 16>
+      segment_coo_broadcast_kernel<scalar_t, ADD, 16>
          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
     else
-      segment_add_coo_broadcast_kernel<scalar_t, ADD, 32>
+      segment_coo_broadcast_kernel<scalar_t, ADD, 32>
          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
   });
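The broadcast COO launches above use a 2D configuration: `blockDim = (32, 8)`, a grid x-dimension that covers the `E` index entries in chunks of `8 * TB`, and a grid y-dimension that covers the `K` columns in chunks of 32. A worked example with assumed sizes:

// Worked example (assumed numbers) of the 2D launch configuration used by
// the broadcast COO kernels: each block holds 32 x 8 threads, and each
// thread handles TB consecutive index entries in one column.
#include <cstdio>

int main() {
  int E = 10000, K = 64, TB = 8;
  int grid_x = (E + (8 * TB) - 1) / (8 * TB); // ceil(E / (8 rows * TB entries))
  int grid_y = (K + 31) / 32;                 // ceil(K / 32 columns per block)
  printf("grid = (%d, %d), block = (32, 8)\n", grid_x, grid_y);
  // grid = (157, 2): 157 * 8 * 8 = 10048 >= E entry slots, 2 * 32 = 64 >= K.
  return 0;
}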