Commit f8c2af4c authored by yuguo's avatar yuguo
Browse files

Merge commit '1d903f5e' of...

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
...@@ -18,4 +18,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id ...@@ -18,4 +18,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
const NVTETensor prob, const int num_rows, const int topK, const int num_cols, const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream = nullptr); cudaStream_t stream = nullptr);
/*! \brief Key/value radix-sort entry point operating on int keys and int values.
 *
 *  Only the declaration is visible here; the notes below are assumptions to confirm
 *  against the implementation.
 *
 *  NOTE(review): the parameter layout (temp_storage / temp_storage_bytes, then input and
 *  output key and value buffers, then an item count) resembles
 *  cub::DeviceRadixSort::SortPairs. If it follows that convention, a call with
 *  temp_storage == nullptr would presumably only write the required scratch size to
 *  *temp_storage_bytes -- TODO confirm.
 *  NOTE(review): unlike nvte_unpermute, this function takes no cudaStream_t argument;
 *  which stream the work runs on cannot be determined from this declaration -- TODO confirm.
 *
 *  \param temp_storage        Scratch buffer pointer (presumably device memory -- confirm).
 *  \param temp_storage_bytes  Pointer to the scratch buffer size in bytes.
 *  \param keys_in             Input keys.
 *  \param keys_out            Output keys.
 *  \param values_in           Input values paired with keys_in.
 *  \param values_out          Output values.
 *  \param num_items           Number of key/value pairs.
 */
void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
int *keys_out, int *values_in, int *values_out, size_t num_items);
#endif // TRANSFORMER_ENGINE_PERMUTATION_H_ #endif // TRANSFORMER_ENGINE_PERMUTATION_H_
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment