Commit cdb862cd authored by wenjh
Browse files

[DCU] Fix launch bounds



Fix launch bounds of multi_tensor_apply_kernel and
thd_out_correction_kernel.
Signed-off-by: wenjh <wenjh@sugon.com>
parent ab9d7598
...@@ -865,7 +865,7 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ ...@@ -865,7 +865,7 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
} }
constexpr int tile = 16; constexpr int tile = 16;
constexpr int block = 512; constexpr int block = 256;
unsigned int grid_x = unsigned int grid_x =
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block; (static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads}; dim3 grid = {grid_x, (unsigned int)num_heads};
......
...@@ -631,7 +631,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -631,7 +631,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam", g_in_type, 1, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, multi_tensor_apply<BLOCK_SIZE, 4>((int64_t)chunk_size, noop_flag,
tensor_lists, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1, AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr, beta2, bias_correction1, bias_correction2, epsilon, lr,
...@@ -642,7 +642,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -642,7 +642,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam", g_in_type, 1, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag,
tensor_lists, tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(), AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
...@@ -655,7 +655,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -655,7 +655,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam", g_in_type, 1, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1, AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);)); (adamMode_t)mode, weight_decay);));
...@@ -664,7 +664,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -664,7 +664,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam", g_in_type, 1, "adam",
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(), AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);)); (adamMode_t)mode, weight_decay);));
...@@ -701,7 +701,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag ...@@ -701,7 +701,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam", g_in_type, 1, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(), AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);)); (adamMode_t)mode, weight_decay);));
...@@ -751,8 +751,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -751,8 +751,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
fp8_dtype, FP8_T, fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam", g_in_type, 0, "adam",
multi_tensor_apply<5, true>( multi_tensor_apply<BLOCK_SIZE, 5, true>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2, AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
} else { } else {
...@@ -760,7 +760,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -760,7 +760,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
fp8_dtype, FP8_T, fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam", g_in_type, 0, "adam",
multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 5, true>(chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(), AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, beta1, beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay);)); lr, (adamMode_t)mode, weight_decay);));
...@@ -778,7 +778,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -778,7 +778,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam", tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
AdamCapturableFunctor<scalar_t_0, float>(), beta1, beta2, AdamCapturableFunctor<scalar_t_0, float>(), beta1, beta2,
step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(), step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
(adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());) (adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
...@@ -796,7 +796,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl ...@@ -796,7 +796,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam", tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
AdamCapturableMasterFunctor<scalar_t_0, float>(), beta1, beta2, AdamCapturableMasterFunctor<scalar_t_0, float>(), beta1, beta2,
step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(), step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
(adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());) (adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
......
...@@ -60,7 +60,7 @@ void multi_tensor_compute_scale_and_scale_inv_cuda( ...@@ -60,7 +60,7 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
float max_fp8, bool force_pow_2_scales, float epsilon) { float max_fp8, bool force_pow_2_scales, float epsilon) {
using namespace at; using namespace at;
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon); ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon);
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
} }
...@@ -240,7 +240,7 @@ struct MaxNormFunctor { ...@@ -240,7 +240,7 @@ struct MaxNormFunctor {
} }
}; };
__global__ void cleanup(float *output, float *output_per_tensor, float *ret, float *ret_per_tensor, __global__ void __launch_bounds__(512) cleanup(float *output, float *output_per_tensor, float *ret, float *ret_per_tensor,
bool per_tensor, int max_chunks_per_tensor) { bool per_tensor, int max_chunks_per_tensor) {
__shared__ float vals[512]; __shared__ float vals[512];
...@@ -338,7 +338,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda( ...@@ -338,7 +338,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
DISPATCH_FLOAT_HALF_AND_BFLOAT( DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 1>(chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(), L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor, per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);) max_chunks_per_tensor);)
...@@ -388,7 +388,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda( ...@@ -388,7 +388,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
DISPATCH_FLOAT_HALF_AND_BFLOAT( DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda", tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 1>(chunk_size, noop_flag, tensor_lists,
UnscaleL2NormFunctor<scalar_t_0>(), inv_scale.data_ptr<float>(), UnscaleL2NormFunctor<scalar_t_0>(), inv_scale.data_ptr<float>(),
output.data_ptr<float>(), output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor, per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
......
...@@ -112,7 +112,7 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -112,7 +112,7 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_HALF_AND_BFLOAT( DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);)) ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);))
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
......
...@@ -152,7 +152,7 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -152,7 +152,7 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
// Case 1. fp16, fp16, fp16, No // Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half && if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half &&
num_tensors == 3) { num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, at::Half, at::Half>(), wd, momentum, dampening, lr, SGDFunctor<3, at::Half, at::Half>(), wd, momentum, dampening, lr,
nesterov, first_run, wd_after_momentum, scale); nesterov, first_run, wd_after_momentum, scale);
} }
...@@ -160,8 +160,7 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -160,8 +160,7 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
// else if (grad_type == at::ScalarType::Half && // else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float && // weight_type == at::ScalarType::Float &&
// num_tensors == 3) { // num_tensors == 3) {
// multi_tensor_apply<3>( // multi_tensor_apply<BLOCK_SIZE, 3>(
// BLOCK_SIZE,
// chunk_size, // chunk_size,
// noop_flag, // noop_flag,
// tensor_lists, // tensor_lists,
...@@ -177,21 +176,21 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -177,21 +176,21 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
// Case 2. fp32, fp32, fp32, No // Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float && // NOLINT(*) else if (grad_type == at::ScalarType::Float && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 3) { weight_type == at::ScalarType::Float && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale); first_run, wd_after_momentum, scale);
} }
// Case 3. fp16, fp32, fp32, Yes // Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half && // NOLINT(*) else if (grad_type == at::ScalarType::Half && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) { weight_type == at::ScalarType::Float && num_tensors == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale); first_run, wd_after_momentum, scale);
} }
// Case 4. fp32, fp32, fp32, Yes // Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float && // NOLINT(*) else if (grad_type == at::ScalarType::Float && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) { weight_type == at::ScalarType::Float && num_tensors == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale); first_run, wd_after_momentum, scale);
} else { } else {
......
...@@ -37,16 +37,16 @@ struct TensorListMetadata<n, true> : public TensorListMetadataBase<n, true> { ...@@ -37,16 +37,16 @@ struct TensorListMetadata<n, true> : public TensorListMetadataBase<n, true> {
void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]]; void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]];
}; };
// Trampoline kernel for multi_tensor_apply: it does no work itself and
// simply forwards the per-launch chunk description to the user-supplied
// functor, which performs the actual per-chunk computation.
//
// block_size is a compile-time template parameter so that
// __launch_bounds__(block_size) can pin the maximum threads-per-block for
// each instantiation, letting the compiler bound register usage per launch
// configuration (the point of this "[DCU] Fix launch bounds" change).
template <int64_t block_size, typename T, typename U, typename... ArgTypes>
__global__ void __launch_bounds__(block_size)
    multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag,
                              T tl, U callable, ArgTypes... args) {
  // Delegate everything to the functor; it interprets the tensor-list
  // metadata (tl) and the noop flag however it likes.
  callable(chunk_size, noop_flag, tl, args...);
}
template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes> template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag, void multi_tensor_apply(int64_t chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable, const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) { ArgTypes... args) {
if constexpr (USE_FP8) { if constexpr (USE_FP8) {
...@@ -90,6 +90,7 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor ...@@ -90,6 +90,7 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
tl.start_tensor_this_launch = 0; tl.start_tensor_this_launch = 0;
int loc_block_info = 0; int loc_block_info = 0;
int loc_tensor_info = 0; int loc_tensor_info = 0;
auto kernel = &multi_tensor_apply_kernel<block_size, TensorListMetadata<depth, USE_FP8>, T, ArgTypes...>;
for (int t = 0; t < ntensors; t++) { for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++) for (int d = 0; d < depth; d++)
...@@ -112,7 +113,7 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor ...@@ -112,7 +113,7 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) { if (tensors_full || blocks_full || last_chunk) {
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>( kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...); chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment