Commit 40fdded5 authored by Zimin Li's avatar Zimin Li
Browse files

issue/127: fix CUDA mix-precision broadcasting input mismatch issue, adjust...

issue/127: fix CUDA mix-precision broadcasting input mismatch issue, adjust comment structure and template variable order
parent b0f75278
......@@ -43,7 +43,7 @@ __device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence<I
* @param offset Linear offset to support partitioned execution.
* @param args Additional arguments passed to the operator.
*/
template <typename Op, typename Tdata, size_t N, typename... Args>
template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_CUDA_KERNEL elementwise_kernel(
size_t output_size,
size_t ndim,
......@@ -129,6 +129,7 @@ __device__ void launch_op(
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ const *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ const *__restrict__ input_strides,
const ptrdiff_t *__restrict__ output_strides,
const void *const *__restrict__ inputs,
Tout *output,
std::index_sequence<Is...>) {
......@@ -137,7 +138,7 @@ __device__ void launch_op(
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, input_strides[0], input_strides[input_id])
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id])
: device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id]));
};
......@@ -200,6 +201,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
input_broadcasted,
input_shapes,
input_strides,
output_strides,
inputs,
output,
std::index_sequence_for<Tin...>{});
......@@ -269,7 +271,7 @@ struct DeviceImpl::Opaque {
size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < info.output_size; i += step) {
elementwise_kernel<Op, Tdata, N, Args...><<<gridDims, blockDims, 0, stream>>>(
elementwise_kernel<N, Op, Tdata, Args...><<<gridDims, blockDims, 0, stream>>>(
info.output_size,
info.ndim,
info.output_contiguous,
......@@ -400,8 +402,8 @@ private:
cudaStream_t stream) const {
CHECK_CUDA(cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream));
......@@ -411,24 +413,24 @@ private:
for (size_t i = 0; i < info.input_size; ++i) {
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i],
info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
info.ndim * sizeof(*tmp_device_ptrs[i]), cudaMemcpyHostToDevice, stream));
}
CHECK_CUDA(cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(),
info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream));
info.input_size * sizeof(*d_input_shapes), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream));
for (size_t i = 0; i < info.input_size; ++i) {
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream));
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*tmp_device_ptrs_strides[i]), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i],
info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));
info.ndim * sizeof(*tmp_device_ptrs_strides[i]), cudaMemcpyHostToDevice, stream));
}
CHECK_CUDA(cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(),
info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream));
info.input_size * sizeof(*d_input_strides), cudaMemcpyHostToDevice, stream));
d_input_contiguous = d_bools;
d_input_broadcasted = d_bools + info.input_size;
d_output_shape = reinterpret_cast<const size_t *>(d_output_shape_strides);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape_strides + info.ndim * sizeof(size_t));
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape_strides + info.ndim * sizeof(*d_output_shape));
return INFINI_STATUS_SUCCESS;
}
......@@ -470,23 +472,7 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info,
return INFINI_STATUS_SUCCESS;
}
/**
* @brief Launches elementwise operation where input types may differ.
*
* Dispatches to templated `calculateImpl` using specified output and input types.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tout Output data type.
* @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args (UNUSED) Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
/* Invoke elementwise operation for different input types */
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
......@@ -503,22 +489,7 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
std::forward<Args>(args)...);
}
/**
* @brief Launches elementwise operation where all input types are the same.
*
* Calls the corresponding templated `calculateImpl` with a unified input type.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tdata Data type for both input and output tensors.
* @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
/* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
......
......@@ -18,11 +18,25 @@ public:
~DeviceImpl() = default;
template <typename... Args>
static infiniStatus_t create(
DeviceImpl **device_info,
Args &&...args);
static infiniStatus_t create(DeviceImpl **device_info, Args &&...args);
/* Invoke elementwise operation when all inputs have the same dtype */
/**
* @brief Launches elementwise operation where all input types are the same.
*
* Calls the corresponding templated `calculateImpl` with a unified input type.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tdata Data type for both input and output tensors.
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
......@@ -31,7 +45,23 @@ public:
void *stream,
Args &&...args);
/* Invoke elementwise operation for different input types */
/**
* @brief Launches elementwise operation where input types may differ.
*
* Dispatches to templated `calculateImpl` using specified output and input types.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tout Output data type.
* @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args (UNUSED) Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment