"tests/cpp/operator/test_causal_softmax.cu" did not exist on "94de051f65d9220303d0a42a97dd28638695212e"
Unverified Commit e5a673f6 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[C][PyTorch] Move cuda kernels from pytorch extensions to core part 1 (#1702)



* Move radix sort to core
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix; change fused_attn to include C header
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix args
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 94bff099
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ #define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#include <cstdint> #include "stdint.h"
#include "transformer_engine.h" #include "transformer_engine.h"
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -18,4 +18,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id ...@@ -18,4 +18,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
const NVTETensor prob, const int num_rows, const int topK, const int num_cols, const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream = nullptr); cudaStream_t stream = nullptr);
void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
int *keys_out, int *values_in, int *values_out, size_t num_items);
#endif // TRANSFORMER_ENGINE_PERMUTATION_H_ #endif // TRANSFORMER_ENGINE_PERMUTATION_H_
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <transformer_engine/permutation.h> #include <transformer_engine/permutation.h>
#include <cub/cub.cuh>
#include "../common.h" #include "../common.h"
static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map, static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
...@@ -367,3 +369,11 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id ...@@ -367,3 +369,11 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK, reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK,
num_cols, stream);); num_cols, stream););
} }
void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
int *keys_out, int *values_in, int *values_out,
size_t num_items) {
NVTE_API_CALL(nvte_device_radix_sort_pairs);
cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, keys_in, keys_out, values_in,
values_out, num_items);
}
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <cub/cub.cuh>
#include "extensions.h" #include "extensions.h"
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
...@@ -28,9 +26,8 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( ...@@ -28,9 +26,8 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
int *temp_ptr = nullptr; nvte_device_radix_sort_pairs(nullptr, &temp_storage_bytes, nullptr, nullptr, nullptr, nullptr,
cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr, max_expanded_token_num);
temp_ptr, max_expanded_token_num);
at::Tensor temp_storage = torch::empty( at::Tensor temp_storage = torch::empty(
temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
...@@ -40,17 +37,18 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( ...@@ -40,17 +37,18 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
workspace.push_back(temp_storage); workspace.push_back(temp_storage);
} }
int *indices_ptr = reinterpret_cast<int *>(getDataPtr(indices, 0)); void *indices_ptr = getDataPtr(indices, 0);
int *sorted_indices_ptr = reinterpret_cast<int *>(getDataPtr(workspace[0], 0)); void *sorted_indices_ptr = getDataPtr(workspace[0], 0);
int *row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[1], 0)); void *row_id_ptr = getDataPtr(workspace[1], 0);
int *sorted_row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[2], 0)); void *sorted_row_id_ptr = getDataPtr(workspace[2], 0);
void *d_temp_storage = getDataPtr(workspace[3], 0); void *d_temp_storage = getDataPtr(workspace[3], 0);
size_t temp_storage_bytes = std::numeric_limits<size_t>::max(); size_t temp_storage_bytes = std::numeric_limits<size_t>::max();
cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr, nvte_device_radix_sort_pairs(
sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, d_temp_storage, &temp_storage_bytes, reinterpret_cast<int *>(indices_ptr),
num_tokens * topK); reinterpret_cast<int *>(sorted_indices_ptr), reinterpret_cast<int *>(row_id_ptr),
reinterpret_cast<int *>(sorted_row_id_ptr), num_tokens * topK);
// Output buffer alloc // Output buffer alloc
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment