Commit 3befaca2 authored by sangwzh's avatar sangwzh
Browse files

update warpsize to 64

parent 910cec0c
......@@ -20,7 +20,7 @@ namespace cuda {
class GpuCache : public torch::CustomClassHolder {
using key_t = long long;
constexpr static int set_associativity = 2;
constexpr static int WARP_SIZE = 32;
constexpr static int WARP_SIZE = 64;
constexpr static int bucket_size = WARP_SIZE * set_associativity;
using gpu_cache_t = ::gpu_cache::gpu_cache<
key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,
......
......@@ -34,7 +34,8 @@ NDArray IndexSelect(NDArray array, IdArray index) {
DType* ret_data = static_cast<DType*>(ret->data);
const DType* array_data = static_cast<DType*>(cuda::GetDevicePointer(array));
const IdType* idx_data = static_cast<IdType*>(index->data);
// const IdType* idx_data = static_cast<IdType*>(index->data);
const IdType* idx_data = static_cast<IdType*>(cuda::GetDevicePointer(index));
hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
if (num_feat == 1) {
......
......@@ -41,7 +41,7 @@ namespace cuda {
template <typename key_t>
class GpuCache : public runtime::Object {
constexpr static int set_associativity = 2;
constexpr static int WARP_SIZE = 32;
constexpr static int WARP_SIZE = 64;
constexpr static int bucket_size = WARP_SIZE * set_associativity;
using gpu_cache_t = gpu_cache::gpu_cache<
key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,
......
......@@ -30,7 +30,7 @@
#endif
#define SET_ASSOCIATIVITY 2
#define SLAB_SIZE 32
#define SLAB_SIZE 64
#define TASK_PER_WARP_TILE_MACRO 1
namespace gpu_cache {
......
......@@ -1251,7 +1251,7 @@ gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, s
return;
}
if (warp_size != 1 && warp_size != 2 && warp_size != 4 && warp_size != 8 && warp_size != 16 &&
warp_size != 32) {
warp_size != 32 &&warp_size != 64) {
printf("Error: Invalid value for warp_size.\n");
return;
}
......@@ -1299,7 +1299,7 @@ gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, s
return;
}
if (warp_size != 1 && warp_size != 2 && warp_size != 4 && warp_size != 8 && warp_size != 16 &&
warp_size != 32) {
warp_size != 32 && warp_size != 64) {
printf("Error: Invalid value for warp_size.\n");
return;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment