Commit c17e3b83 authored by wooway777
Browse files

issue/1033 - fix arg with cudart version

parent 9015e384
......@@ -16,10 +16,21 @@ static cudaError argMax_(
void *workspace_ptr,
size_t &workspace_len,
cudaStream_t stream) {
#if CUDART_VERSION >= 11000
// New interface: separate value and index outputs
T* max_value = &kv_pair->value;
int* max_index = &kv_pair->key;
return cub::DeviceReduce::ArgMax(
workspace_ptr, workspace_len,
logits, max_value, max_index, n,
stream);
#else
// Old interface
return cub::DeviceReduce::ArgMax(
workspace_ptr, workspace_len,
logits, kv_pair, n,
stream);
#endif
}
template <class Tval, class Tidx>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment