Commit 8f7de847 authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

dtk

parent f262efc9
Pipeline #248 failed with stages
in 0 seconds
# OneFlow
# OneFlow
OneFlow is a deep learning framework designed to be **user-friendly, scalable and efficient**. With OneFlow, it is easy to:
- program a model with **PyTorch-like API**
- scale a model to n-dimensional-parallel/distributed execution with the **Global View API**
- accelerate/deploy a model with the **Static Graph Compiler**.
**OneFlow is a performance-centered and open-source deep learning framework.**
[![Simple CI](https://github.com/Oneflow-Inc/oneflow/actions/workflows/simple.yml/badge.svg)](https://github.com/Oneflow-Inc/oneflow/actions/workflows/simple.yml)
[![Nightly Docker Image](https://github.com/Oneflow-Inc/docker-images/actions/workflows/oneflow-nightly.yml/badge.svg)](https://github.com/Oneflow-Inc/docker-images/actions/workflows/oneflow-nightly.yml)
......@@ -12,8 +9,10 @@ OneFlow is a deep learning framework designed to be **user-friendly, scalable an
## Latest News
- Version 0.8.0 is out!
- [Full changelog](https://github.com/Oneflow-Inc/oneflow/releases/tag/v0.8.0)
- Version 0.7.0 is out!
- Introducing global tensor
- Semi-auto parallelization has landed
- [Full changelog](https://github.com/Oneflow-Inc/oneflow/releases/tag/v0.7.0)
## Publication
......@@ -36,7 +35,7 @@ OneFlow is a deep learning framework designed to be **user-friendly, scalable an
### System Requirements
- Linux. As for now, there is no pre-built release for macOS, Windows.
- Python 3.7, 3.8, 3.9, 3.10
- Python 3.6, 3.7, 3.8, 3.9, 3.10
- (**Highly recommended**) Upgrade pip
```
......@@ -54,7 +53,7 @@ OneFlow is a deep learning framework designed to be **user-friendly, scalable an
- To install latest stable release of OneFlow with CUDA support:
```bash
python3 -m pip install oneflow
python3 -m pip install -f https://release.oneflow.info oneflow==0.7.0+cu102
```
- To install nightly release of OneFlow with CUDA support:
......@@ -67,7 +66,7 @@ OneFlow is a deep learning framework designed to be **user-friendly, scalable an
- Stable
```bash
python3 -m pip install --find-links https://release.oneflow.info oneflow==0.8.0+[PLATFORM]
python3 -m pip install --find-links https://release.oneflow.info oneflow==0.7.0+[PLATFORM]
```
- Nightly
```
......
# Monkey patch to not ship libjvm.so in pypi wheels
import sys
from auditwheel.main import main
from auditwheel.policy import _POLICIES as POLICIES
# libjvm is loaded dynamically; do not include it
for p in POLICIES:
p['lib_whitelist'].append('librccl.so.1')
p['lib_whitelist'].append('libhipblas.so.0')
p['lib_whitelist'].append('libhiprand.so.1')
p['lib_whitelist'].append('librocrand.so.1')
p['lib_whitelist'].append('libMIOpen.so.1')
p['lib_whitelist'].append('libgalaxyhip.so.4')
p['lib_whitelist'].append('librocm_smi64.so.2')
p['lib_whitelist'].append('librocsolver.so.0 ')
p['lib_whitelist'].append('librocblas.so.0')
if __name__ == "__main__":
sys.exit(main())
# Monkey patch to not ship libjvm.so in pypi wheels
import sys
from auditwheel.main import main
from auditwheel.policy import _POLICIES as POLICIES
# libjvm is loaded dynamically; do not include it
for p in POLICIES:
p['lib_whitelist'].append('librccl.so.1')
p['lib_whitelist'].append('libhipblas.so.0')
p['lib_whitelist'].append('libhiprand.so.1')
p['lib_whitelist'].append('librocrand.so.1')
p['lib_whitelist'].append('libMIOpen.so.1')
p['lib_whitelist'].append('libgalaxyhip.so.5')
p['lib_whitelist'].append('librocm_smi64.so.2')
p['lib_whitelist'].append('librocsolver.so.0 ')
p['lib_whitelist'].append('librocblas.so.0')
if __name__ == "__main__":
sys.exit(main())
......@@ -328,6 +328,17 @@ if(BUILD_PYTHON OR BUILD_CPP_API)
endif()
endif()
if (BUILD_ROCM)
# AMD compiler fails to compile these three files with '-O1/2/3'.
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# so '-O0' will override '-O1/2/3'.
set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel.hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel.hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel.hip.cpp
# ${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math.hip.cpp
PROPERTIES COMPILE_OPTIONS "-O0")
endif()
if(BUILD_PYTHON)
# py ext lib
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/cached_key_value_store.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
namespace oneflow {
namespace embedding {
namespace {
template<typename Key, typename Elem>
__global__ void PostStoreGetKernel(uint32_t num_cache_missing, uint32_t num_store_missing,
uint32_t num_elems_per_value,
const uint32_t* cache_missing_indices,
const uint32_t* store_missing_indices, const Elem* store_values,
Elem* values, uint32_t* missing_indices) {
const uint32_t num_cache_missing_elem = num_cache_missing * num_elems_per_value;
CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_cache_missing_elem) {
const uint32_t value_index = i / num_elems_per_value;
const uint32_t elem_index = i - value_index * num_elems_per_value;
values[cache_missing_indices[value_index] * num_elems_per_value + elem_index] = store_values[i];
}
CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_store_missing) {
missing_indices[i] = cache_missing_indices[store_missing_indices[i]];
}
}
template<typename Key, typename Elem>
class CacheKeyValueStoreImpl : public KeyValueStore {
public:
OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl);
CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache)
: store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
CHECK_EQ(store_->KeySize(), cache_->KeySize());
CHECK_EQ(store_->ValueSize(), cache_->ValueSize());
OF_CUDA_CHECK(hipMalloc(&num_buffer_, sizeof(uint32_t)));
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&host_num_buffer_), sizeof(uint32_t)));
num_elems_per_value_ = store_->ValueSize() / sizeof(Elem);
}
~CacheKeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipFree(num_buffer_));
OF_CUDA_CHECK(hipHostFree(host_num_buffer_));
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(keys_buffer_));
OF_CUDA_CHECK(hipFree(values_buffer_));
OF_CUDA_CHECK(hipFree(indices_buffer0_));
OF_CUDA_CHECK(hipFree(indices_buffer1_));
}
cache_.reset();
store_.reset();
}
uint32_t KeySize() const override { return store_->KeySize(); }
uint32_t ValueSize() const override { return store_->ValueSize(); }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); }
if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(keys_buffer_));
OF_CUDA_CHECK(hipFree(values_buffer_));
OF_CUDA_CHECK(hipFree(indices_buffer0_));
OF_CUDA_CHECK(hipFree(indices_buffer1_));
}
OF_CUDA_CHECK(hipMalloc(&keys_buffer_, query_length * store_->KeySize()));
OF_CUDA_CHECK(hipMalloc(&values_buffer_, query_length * store_->ValueSize()));
OF_CUDA_CHECK(hipMalloc(&indices_buffer0_, query_length * sizeof(uint32_t)));
OF_CUDA_CHECK(hipMalloc(&indices_buffer1_, query_length * sizeof(uint32_t)));
max_query_length_ = query_length;
}
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint8_t* mask) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
const void* update, const float* lr, float scale) override;
bool IsFusionSupported() override {
return cache_->Policy() == CacheOptions::Policy::kFull
&& cache_->ValueType() == DataType::kFloat;
}
bool SnapshotExists(const std::string& name) override;
void LoadSnapshot(const std::string& name) override;
void SaveSnapshot(const std::string& name) override;
void LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) override;
private:
void SyncCacheToStore();
std::unique_ptr<KeyValueStore> store_;
std::unique_ptr<Cache> cache_;
uint32_t* num_buffer_{};
uint32_t* host_num_buffer_{};
Key* keys_buffer_{};
Elem* values_buffer_{};
uint32_t* indices_buffer0_{};
uint32_t* indices_buffer1_{};
int device_index_{};
uint32_t max_query_length_;
uint32_t num_elems_per_value_{};
std::recursive_mutex mutex_;
bool synced_;
};
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint32_t* n_missing,
uint32_t* missing_indices) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
if (cache_->Policy() == CacheOptions::Policy::kFull) {
cache_->Get(stream, num_keys, keys, values, n_missing, keys_buffer_, missing_indices);
return;
} else {
cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_);
}
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
const uint32_t num_cache_missing = *host_num_buffer_;
if (num_cache_missing == 0) {
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
const uint32_t num_store_missing = *host_num_buffer_;
RUN_CUDA_KERNEL((PostStoreGetKernel<Key, Elem>), stream, num_cache_missing * num_elems_per_value_,
num_cache_missing, num_store_missing, num_elems_per_value_, indices_buffer0_,
indices_buffer1_, values_buffer_, static_cast<Elem*>(values), missing_indices);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint8_t* mask) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (cache_->Policy() == CacheOptions::Policy::kFull) {
cache_->Get(stream, num_keys, keys, values, mask);
return;
} else {
UNIMPLEMENTED();
}
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
const void* values) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
synced_ = false;
auto cuda_stream = stream->As<ep::CudaStream>();
cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_);
if (cache_->Policy() == CacheOptions::Policy::kFull) { return; }
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, uint32_t num_keys,
const void* keys, const void* values,
const void* update, const float* lr,
float scale) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) {
UNIMPLEMENTED();
}
synced_ = false;
cache_->FusedHalfUpdatePut(stream, num_keys, keys, values, update, lr, scale, num_buffer_,
keys_buffer_, values_buffer_);
}
template<typename Key, typename Elem>
bool CacheKeyValueStoreImpl<Key, Elem>::SnapshotExists(const std::string& name) {
return store_->SnapshotExists(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(const std::string& name) {
LoadSnapshot(name, nullptr);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(
const std::string& name, const std::function<void(KVIterator* iter)>& Hook) {
CudaCurrentDeviceGuard guard(device_index_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
CHECK_GT(max_query_length_, 0);
cache_->Clear();
auto device =
Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
CHECK(device);
auto* stream = device->CreateStream();
store_->LoadSnapshot(name, [&](KVIterator* iter) {
if (cache_->Policy() == CacheOptions::Policy::kFull) {
auto* cuda_stream = stream->As<ep::CudaStream>();
while (true) {
iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_);
OF_CUDA_CHECK(hipDeviceSynchronize());
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
if (*host_num_buffer_ == 0) { return; }
cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr,
nullptr);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
CHECK_EQ(*host_num_buffer_, 0);
}
}
if (Hook) {
iter->Reset();
Hook(iter);
}
});
device->DestroyStream(stream);
store_->LoadSnapshot(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SaveSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
SyncCacheToStore();
store_->SaveSnapshot(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() {
if (synced_) { return; }
CudaCurrentDeviceGuard guard(device_index_);
auto device =
Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
CHECK(device);
auto* stream = device->CreateStream();
auto* cuda_stream = stream->As<ep::CudaStream>();
const uint64_t dump_capacity = cache_->DumpCapacity();
CHECK_GT(max_query_length_, 0);
for (uint64_t start_key_index = 0; start_key_index < dump_capacity;
start_key_index += max_query_length_) {
cache_->Dump(stream, start_key_index,
std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_,
keys_buffer_, values_buffer_);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
if (*host_num_buffer_ == 0) { continue; }
store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
CHECK_JUST(stream->Sync());
}
device->DestroyStream(stream);
synced_ = true;
}
template<typename Key>
std::unique_ptr<KeyValueStore> DispatchElemType(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
const uint32_t value_size = store->ValueSize();
if (value_size % sizeof(uint4) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint4>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint64_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint64_t>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint32_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint32_t>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint16_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint16_t>(std::move(store), std::move(cache)));
} else {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint8_t>(std::move(store), std::move(cache)));
}
}
std::unique_ptr<KeyValueStore> DispatchKeyType(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
const uint32_t key_size = store->KeySize();
if (key_size == 4) {
return DispatchElemType<uint32_t>(std::move(store), std::move(cache));
} else if (key_size == 8) {
return DispatchElemType<uint64_t>(std::move(store), std::move(cache));
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace
std::unique_ptr<KeyValueStore> NewCachedKeyValueStore(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
return DispatchKeyType(std::move(store), std::move(cache));
}
} // namespace embedding
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/cached_key_value_store.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
namespace oneflow {
namespace embedding {
namespace {
template<typename Key, typename Elem>
__global__ void PostStoreGetKernel(uint32_t num_cache_missing, uint32_t num_store_missing,
uint32_t num_elems_per_value,
const uint32_t* cache_missing_indices,
const uint32_t* store_missing_indices, const Elem* store_values,
Elem* values, uint32_t* missing_indices) {
const uint32_t num_cache_missing_elem = num_cache_missing * num_elems_per_value;
CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_cache_missing_elem) {
const uint32_t value_index = i / num_elems_per_value;
const uint32_t elem_index = i - value_index * num_elems_per_value;
values[cache_missing_indices[value_index] * num_elems_per_value + elem_index] = store_values[i];
}
CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_store_missing) {
missing_indices[i] = cache_missing_indices[store_missing_indices[i]];
}
}
template<typename Key, typename Elem>
class CacheKeyValueStoreImpl : public KeyValueStore {
public:
OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl);
CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache)
: store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
CHECK_EQ(store_->KeySize(), cache_->KeySize());
CHECK_EQ(store_->ValueSize(), cache_->ValueSize());
OF_CUDA_CHECK(hipMalloc(&num_buffer_, sizeof(uint32_t)));
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&host_num_buffer_), sizeof(uint32_t)));
num_elems_per_value_ = store_->ValueSize() / sizeof(Elem);
}
~CacheKeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipFree(num_buffer_));
OF_CUDA_CHECK(hipHostFree(host_num_buffer_));
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(keys_buffer_));
OF_CUDA_CHECK(hipFree(values_buffer_));
OF_CUDA_CHECK(hipFree(indices_buffer0_));
OF_CUDA_CHECK(hipFree(indices_buffer1_));
}
cache_.reset();
store_.reset();
}
uint32_t KeySize() const override { return store_->KeySize(); }
uint32_t ValueSize() const override { return store_->ValueSize(); }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); }
if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(keys_buffer_));
OF_CUDA_CHECK(hipFree(values_buffer_));
OF_CUDA_CHECK(hipFree(indices_buffer0_));
OF_CUDA_CHECK(hipFree(indices_buffer1_));
}
OF_CUDA_CHECK(hipMalloc(&keys_buffer_, query_length * store_->KeySize()));
OF_CUDA_CHECK(hipMalloc(&values_buffer_, query_length * store_->ValueSize()));
OF_CUDA_CHECK(hipMalloc(&indices_buffer0_, query_length * sizeof(uint32_t)));
OF_CUDA_CHECK(hipMalloc(&indices_buffer1_, query_length * sizeof(uint32_t)));
max_query_length_ = query_length;
}
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint8_t* mask) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
const void* update, const float* lr, float scale) override;
bool IsFusionSupported() override {
return cache_->Policy() == CacheOptions::Policy::kFull
&& cache_->ValueType() == DataType::kFloat;
}
bool SnapshotExists(const std::string& name) override;
void LoadSnapshot(const std::string& name) override;
void SaveSnapshot(const std::string& name) override;
void LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) override;
private:
void SyncCacheToStore();
std::unique_ptr<KeyValueStore> store_;
std::unique_ptr<Cache> cache_;
uint32_t* num_buffer_{};
uint32_t* host_num_buffer_{};
Key* keys_buffer_{};
Elem* values_buffer_{};
uint32_t* indices_buffer0_{};
uint32_t* indices_buffer1_{};
int device_index_{};
uint32_t max_query_length_;
uint32_t num_elems_per_value_{};
std::recursive_mutex mutex_;
bool synced_;
};
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint32_t* n_missing,
uint32_t* missing_indices) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
if (cache_->Policy() == CacheOptions::Policy::kFull) {
cache_->Get(stream, num_keys, keys, values, n_missing, keys_buffer_, missing_indices);
return;
} else {
cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_);
}
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
const uint32_t num_cache_missing = *host_num_buffer_;
if (num_cache_missing == 0) {
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
const uint32_t num_store_missing = *host_num_buffer_;
RUN_CUDA_KERNEL((PostStoreGetKernel<Key, Elem>), stream, num_cache_missing * num_elems_per_value_,
num_cache_missing, num_store_missing, num_elems_per_value_, indices_buffer0_,
indices_buffer1_, values_buffer_, static_cast<Elem*>(values), missing_indices);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint8_t* mask) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (cache_->Policy() == CacheOptions::Policy::kFull) {
cache_->Get(stream, num_keys, keys, values, mask);
return;
} else {
UNIMPLEMENTED();
}
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
const void* values) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
synced_ = false;
auto cuda_stream = stream->As<ep::CudaStream>();
cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_);
if (cache_->Policy() == CacheOptions::Policy::kFull) { return; }
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, uint32_t num_keys,
const void* keys, const void* values,
const void* update, const float* lr,
float scale) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) {
UNIMPLEMENTED();
}
synced_ = false;
cache_->FusedHalfUpdatePut(stream, num_keys, keys, values, update, lr, scale, num_buffer_,
keys_buffer_, values_buffer_);
}
template<typename Key, typename Elem>
bool CacheKeyValueStoreImpl<Key, Elem>::SnapshotExists(const std::string& name) {
return store_->SnapshotExists(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(const std::string& name) {
LoadSnapshot(name, nullptr);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(
const std::string& name, const std::function<void(KVIterator* iter)>& Hook) {
CudaCurrentDeviceGuard guard(device_index_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
CHECK_GT(max_query_length_, 0);
cache_->Clear();
auto device =
Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
CHECK(device);
auto* stream = device->CreateStream();
store_->LoadSnapshot(name, [&](KVIterator* iter) {
if (cache_->Policy() == CacheOptions::Policy::kFull) {
auto* cuda_stream = stream->As<ep::CudaStream>();
while (true) {
iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_);
OF_CUDA_CHECK(hipDeviceSynchronize());
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
if (*host_num_buffer_ == 0) { return; }
cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr,
nullptr);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
CHECK_EQ(*host_num_buffer_, 0);
}
}
if (Hook) {
iter->Reset();
Hook(iter);
}
});
device->DestroyStream(stream);
store_->LoadSnapshot(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SaveSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
SyncCacheToStore();
store_->SaveSnapshot(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() {
if (synced_) { return; }
CudaCurrentDeviceGuard guard(device_index_);
auto device =
Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
CHECK(device);
auto* stream = device->CreateStream();
auto* cuda_stream = stream->As<ep::CudaStream>();
const uint64_t dump_capacity = cache_->DumpCapacity();
CHECK_GT(max_query_length_, 0);
for (uint64_t start_key_index = 0; start_key_index < dump_capacity;
start_key_index += max_query_length_) {
cache_->Dump(stream, start_key_index,
std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_,
keys_buffer_, values_buffer_);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
if (*host_num_buffer_ == 0) { continue; }
store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
CHECK_JUST(stream->Sync());
}
device->DestroyStream(stream);
synced_ = true;
}
template<typename Key>
std::unique_ptr<KeyValueStore> DispatchElemType(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
const uint32_t value_size = store->ValueSize();
if (value_size % sizeof(uint4) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint4>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint64_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint64_t>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint32_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint32_t>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint16_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint16_t>(std::move(store), std::move(cache)));
} else {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint8_t>(std::move(store), std::move(cache)));
}
}
std::unique_ptr<KeyValueStore> DispatchKeyType(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
const uint32_t key_size = store->KeySize();
if (key_size == 4) {
return DispatchElemType<uint32_t>(std::move(store), std::move(cache));
} else if (key_size == 8) {
return DispatchElemType<uint64_t>(std::move(store), std::move(cache));
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace
std::unique_ptr<KeyValueStore> NewCachedKeyValueStore(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
return DispatchKeyType(std::move(store), std::move(cache));
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/full_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include "oneflow/core/hip/atomic.hip.h"
namespace oneflow {
namespace embedding {
using Key32 = unsigned int;
using Key64 = unsigned long long int;
using Key128 = ulonglong2;
namespace {
template<typename Key, typename Index>
__device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Index* table_size,
Key key, Index* out) {
Key key_hi = (key | 0x1);
Key key_lo = (key & 0x1);
Index index_plus_one = 0;
Key old_entry_key = cuda::atomic::CAS(entry_key, static_cast<Key>(0), key_hi);
while (index_plus_one == 0) {
if (old_entry_key == static_cast<Key>(0)) {
Index index = cuda::atomic::Add(table_size, static_cast<Index>(1));
index_plus_one = index + 1;
*entry_index = ((index_plus_one << 1U) | key_lo);
*out = index_plus_one;
return true;
} else if (old_entry_key == key_hi) {
const Index entry_index_val = *entry_index;
if (entry_index_val == 0) {
// do nothing
} else if ((entry_index_val & 0x1) == key_lo) {
*out = (entry_index_val >> 1U);
return true;
} else {
return false;
}
} else {
return false;
}
}
return false;
}
template<typename Key, typename Index>
__device__ bool GetOrInsertOne(const size_t capacity, Key* table_keys, Index* table_indices,
Index* table_size, Key key, size_t hash, Index* out) {
const size_t start_idx = hash % capacity;
for (size_t count = 0; count < capacity; ++count) {
const size_t idx = (start_idx + count) % capacity;
Key* entry_key = table_keys + idx;
Index* entry_index = table_indices + idx;
if (TryGetOrInsert<Key, Index>(entry_key, entry_index, table_size, key, out)) { return true; }
}
return false;
}
template<typename Key, typename Index>
__device__ bool GetOne(const size_t capacity, Key* table_keys, Index* table_indices, Key key,
size_t hash, Index* out) {
const size_t start_idx = hash % capacity;
for (size_t count = 0; count < capacity; ++count) {
const size_t idx = (start_idx + count) % capacity;
Key entry_key = table_keys[idx];
Key entry_index = table_indices[idx];
Key key_hi = (key | 0x1);
Key key_lo = (key & 0x1);
if (entry_key == 0) { break; }
if (entry_key == key_hi) {
if ((entry_index & 0x1) == key_lo) {
*out = (entry_index >> 1U);
return true;
}
}
}
*out = 0;
return false;
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
Index* table_size, uint32_t num_keys, const Key* keys,
Index* context) {
CUDA_1D_KERNEL_LOOP(i, num_keys) {
Key key = keys[i];
uint64_t hash = FullCacheHash()(key);
bool success = GetOrInsertOne<Key, Index>(capacity, table_keys, table_indices, table_size, key,
hash, context + i);
assert(success);
}
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeLookupKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
uint32_t num_keys, const Key* keys, Index* context) {
CUDA_1D_KERNEL_LOOP(i, num_keys) {
Key key = keys[i];
uint64_t hash = FullCacheHash()(key);
GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, context + i);
}
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeDumpKernel(const Key* table_keys, const Index* table_indices,
uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, Key* keys, Index* context) {
CUDA_1D_KERNEL_LOOP(i, (end_key_index - start_key_index)) {
Key entry_key = table_keys[i + start_key_index];
Index entry_index = table_indices[i + start_key_index];
if (entry_index != 0) {
uint32_t index = cuda::atomic::Add(n_dumped, static_cast<uint32_t>(1));
keys[index] = ((entry_key ^ 0x1) | (entry_index & 0x1));
context[index] = (entry_index >> 1U);
}
}
}
template<typename Key, typename Elem, typename Index, bool return_value>
__global__ void LookupKernel(uint32_t value_length, const Elem* cache_values,
uint32_t values_elem_cnt, const Key* keys, const Index* context,
Elem* values, uint32_t* n_missing, Key* missing_keys,
uint32_t* missing_indices) {
CUDA_1D_KERNEL_LOOP(i, values_elem_cnt) {
const uint64_t key_id = i / value_length;
const uint64_t ctx = context[key_id];
const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * value_length;
if (ctx == 0) {
const Key missing_key = keys[key_id];
if (col_id == 0) {
const uint32_t old_n_missing = cuda::atomic::Add(n_missing, static_cast<uint32_t>(1));
missing_keys[old_n_missing] = missing_key;
missing_indices[old_n_missing] = key_id;
}
continue;
}
if (return_value) { values[i] = cache_values[row_id * value_length + col_id]; }
}
}
template<typename Key, typename Elem, typename Index, uint32_t block_size>
__global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_values,
uint32_t values_elem_cnt, const Key* keys, const Index* context,
Elem* values, uint32_t* n_missing, Key* missing_keys,
uint32_t* missing_indices, const size_t capacity,
Key* table_keys, Index* table_indices) {
constexpr uint32_t warp_size = 32;
constexpr uint32_t n_warp_per_block = block_size / warp_size;
const uint32_t warp_id = threadIdx.x / warp_size;
const uint32_t lane_id = threadIdx.x % warp_size;
const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
const uint32_t n_keys = values_elem_cnt / value_length;
__shared__ Key batch_keys[n_warp_per_block][warp_size];
__shared__ Index batch_row_ids[n_warp_per_block][warp_size];
__shared__ Key batch_missing_keys[n_warp_per_block][warp_size];
__shared__ uint32_t batch_missing_indices[n_warp_per_block][warp_size];
__shared__ uint32_t batch_n_missing[n_warp_per_block];
for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
batch_start += global_n_warp * warp_size) {
const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
if (lane_id == 0) { batch_n_missing[warp_id] = 0; }
__syncthreads();
const uint32_t key_offset = batch_start + lane_id;
if (key_offset < n_keys) {
const Key key = keys[batch_start + lane_id];
const uint64_t hash = FullCacheHash()(key);
Index row;
GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
batch_row_ids[warp_id][lane_id] = row;
if (row == 0) {
const uint32_t batch_missing_idx = atomicAdd(batch_n_missing + warp_id, 1);
batch_missing_keys[warp_id][batch_missing_idx] = key;
batch_missing_indices[warp_id][batch_missing_idx] = key_offset;
}
}
__syncthreads();
const uint32_t batch_n_missing_t = batch_n_missing[warp_id];
if (lane_id == 0) {
const uint32_t old_n_missing =
cuda::atomic::Add(n_missing, static_cast<uint32_t>(batch_n_missing_t));
batch_n_missing[warp_id] = old_n_missing;
}
__syncthreads();
if (lane_id < batch_n_missing_t) {
missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id];
missing_indices[batch_n_missing[warp_id] + lane_id] = batch_missing_indices[warp_id][lane_id];
}
for (int i = 0; i < batch_n_key; ++i) {
const Key key = batch_keys[warp_id][i];
const Index row = batch_row_ids[warp_id][i];
if (row == 0) { continue; }
for (int col = lane_id; col < value_length; col += warp_size) {
values[(batch_start + i) * value_length + col] =
cache_values[(row - 1) * value_length + col];
}
}
__syncthreads();
}
}
template<typename T, size_t pack_size>
struct alignas(sizeof(T) * pack_size) Pack {
T elem[pack_size];
};
template<typename Key, typename Elem, typename Index, uint32_t block_size, uint32_t pack_size>
__global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __restrict__ cache_values,
uint32_t values_elem_cnt, const Key* __restrict__ keys,
const Index* __restrict__ context, Elem* __restrict__ values,
uint8_t* __restrict__ mask, const size_t capacity,
Key* __restrict__ table_keys,
Index* __restrict__ table_indices) {
const uint32_t packed_cols = value_length / pack_size;
auto* packed_values = reinterpret_cast<Pack<Elem, pack_size>*>(values);
const auto* packed_cache_values = reinterpret_cast<const Pack<Elem, pack_size>*>(cache_values);
constexpr uint32_t warp_size = 32;
constexpr uint32_t n_warp_per_block = block_size / warp_size;
const uint32_t warp_id = threadIdx.x / warp_size;
const uint32_t lane_id = threadIdx.x % warp_size;
const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
const uint32_t n_keys = values_elem_cnt / value_length;
__shared__ Key batch_keys[n_warp_per_block][warp_size];
__shared__ Index batch_row_ids[n_warp_per_block][warp_size];
for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
batch_start += global_n_warp * warp_size) {
const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
const uint32_t key_offset = batch_start + lane_id;
if (key_offset < n_keys) {
const Key key = keys[batch_start + lane_id];
const uint64_t hash = FullCacheHash()(key);
Index row;
GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
batch_row_ids[warp_id][lane_id] = row;
mask[key_offset] = row > 0;
}
__syncthreads();
for (int i = 0; i < batch_n_key; ++i) {
const Key key = batch_keys[warp_id][i];
const Index row = batch_row_ids[warp_id][i];
if (row == 0) { continue; }
#pragma unroll 4
for (int col = lane_id; col < packed_cols; col += warp_size) {
packed_values[(batch_start + i) * packed_cols + col] =
packed_cache_values[(row - 1) * packed_cols + col];
}
}
__syncthreads();
}
}
template<typename Elem, typename Index, size_t pack_size>
__global__ void UpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
const Index* context, const Elem* values) {
const int packed_values_elem_cnt = values_elem_cnt / pack_size;
const uint32_t packed_elem_cnt = value_length / pack_size;
auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {
const uint64_t key_id = i / packed_elem_cnt;
const uint64_t ctx = context[key_id];
if (ctx == 0) { continue; }
const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * packed_elem_cnt;
packed_cache_values[row_id * packed_elem_cnt + col_id] = packed_values[i];
}
}
template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* __restrict__ cache_values,
uint32_t values_elem_cnt, const Index* __restrict__ context,
const Elem* __restrict__ values, const half* __restrict__ update,
const float* __restrict__ lr, float scale) {
const int packed_values_elem_cnt = values_elem_cnt / pack_size;
const uint32_t packed_elem_cnt = value_length / pack_size;
auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
auto* packed_update = reinterpret_cast<const Pack<half, pack_size>*>(update);
const float alpha = -*lr * scale;
CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {
const uint64_t key_id = i / packed_elem_cnt;
const uint64_t ctx = context[key_id];
if (ctx == 0) { continue; }
const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * packed_elem_cnt;
Pack<Elem, pack_size> m = packed_values[i];
Pack<half, pack_size> u = packed_update[i];
for (size_t j = 0; j < pack_size; ++j) { m.elem[j] += static_cast<Elem>(u.elem[j]) * alpha; }
packed_cache_values[row_id * packed_elem_cnt + col_id] = m;
}
}
template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
const Index* context, const Elem* values, const half* update, const float* lr,
float scale) {
asm volatile("s_trap 0;");
}
template<typename Key, typename Elem, typename Index>
__global__ void DumpValueKernel(uint32_t value_length, const uint32_t* n_dumped,
const Index* context, const Elem* cache_values, Elem* values) {
CUDA_1D_KERNEL_LOOP(i, *n_dumped * value_length) {
const uint64_t key_id = i / value_length;
const uint64_t ctx = context[key_id];
const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * value_length;
values[i] = cache_values[row_id * value_length + col_id];
}
}
template<typename Key, typename Index>
class OrdinalEncoder {
public:
OF_DISALLOW_COPY_AND_MOVE(OrdinalEncoder);
explicit OrdinalEncoder(uint64_t capacity, float load_factor)
: capacity_(capacity), table_capacity_(capacity / load_factor) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
OF_CUDA_CHECK(hipMalloc(&table_size_, sizeof(Index)));
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&table_size_host_), sizeof(Index)));
OF_CUDA_CHECK(hipMalloc(&table_keys_, table_capacity_ * sizeof(Key)));
OF_CUDA_CHECK(hipMalloc(&table_indices_, table_capacity_ * sizeof(Index)));
Clear();
}
~OrdinalEncoder() {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipFree(table_size_));
OF_CUDA_CHECK(hipHostFree(table_size_host_));
OF_CUDA_CHECK(hipFree(table_keys_));
OF_CUDA_CHECK(hipFree(table_indices_));
}
template<bool insert>
void Encode(ep::Stream* stream, uint32_t num_keys, const Key* keys, Index* context) {
if (insert) {
RUN_CUDA_KERNEL((OrdinalEncodeKernel<Key, Index>), stream, num_keys, table_capacity_,
table_keys_, table_indices_, table_size_, num_keys, keys, context);
OF_CUDA_CHECK(hipMemcpyAsync(table_size_host_, table_size_, sizeof(Index), hipMemcpyDefault,
stream->As<ep::CudaStream>()->cuda_stream()));
CHECK_JUST(stream->Sync());
CHECK_LT(*table_size_host_, capacity_)
<< "The number of key is larger than cache size, please enlarge cache_memory_budget. ";
} else {
RUN_CUDA_KERNEL((OrdinalEncodeLookupKernel<Key, Index>), stream, num_keys, table_capacity_,
table_keys_, table_indices_, num_keys, keys, context);
}
}
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, Key* keys, Index* context) {
OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index>), stream, end_key_index - start_key_index,
table_keys_, table_indices_, start_key_index, end_key_index, n_dumped, keys,
context);
}
void Clear() {
OF_CUDA_CHECK(hipMemset(table_size_, 0, sizeof(Index)));
OF_CUDA_CHECK(hipMemset(table_keys_, 0, table_capacity_ * sizeof(Key)));
OF_CUDA_CHECK(hipMemset(table_indices_, 0, table_capacity_ * sizeof(Index)));
}
uint64_t TableCapacity() const { return table_capacity_; }
Key* table_keys() const { return table_keys_; }
Index* table_indices() const { return table_indices_; }
private:
int device_index_{};
Key* table_keys_;
Index* table_indices_;
uint64_t capacity_;
uint64_t table_capacity_;
Index* table_size_{};
Index* table_size_host_{};
};
template<typename Key, typename Elem, typename Index, size_t pack_size>
class CacheImpl : public Cache {
public:
OF_DISALLOW_COPY_AND_MOVE(CacheImpl);
explicit CacheImpl(const CacheOptions& options)
: encoder_(options.capacity, options.load_factor),
device_index_(-1),
options_(options),
max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
const uint64_t values_size = options.capacity * options.value_size;
if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(hipMalloc(&values_, values_size));
} else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&values_), values_size));
} else {
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&values_),
values_size));
}
} else {
UNIMPLEMENTED();
}
num_elem_per_value_ = options_.value_size / sizeof(Elem);
}
~CacheImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (options_.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(hipFree(values_));
} else if (options_.value_memory_kind == CacheOptions::MemoryKind::kHost) {
OF_CUDA_CHECK(hipHostFree(values_));
} else {
UNIMPLEMENTED();
}
if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
}
uint64_t Capacity() const override { return options_.capacity; }
uint64_t DumpCapacity() const override { return encoder_.TableCapacity(); }
uint32_t KeySize() const override { return options_.key_size; }
uint32_t ValueSize() const override { return options_.value_size; }
DataType ValueType() const override { return options_.value_type; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
OF_CUDA_CHECK(hipMalloc(&encoding_buffer_, query_length * sizeof(uint64_t)));
max_query_length_ = query_length;
}
CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kFull; }
void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override;
void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override;
void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,
uint8_t* mask) override;
void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override;
void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
const void* update, const float* lr, float scale, uint32_t* n_evicted,
void* evicted_keys, void* evicted_values) override;
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, void* keys, void* values) override;
void Clear() override;
private:
OrdinalEncoder<Key, Index> encoder_;
int device_index_;
uint32_t num_elem_per_value_{};
Elem* values_;
Index* encoding_buffer_{};
CacheOptions options_;
uint32_t max_query_length_;
};
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Test(ep::Stream* stream, uint32_t n_keys,
const void* keys, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) {
OF_CUDA_CHECK(
hipMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<false>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((LookupKernel<Key, Elem, Index, false>), stream, values_elem_cnt,
num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
encoding_buffer_, nullptr, n_missing, static_cast<Key*>(missing_keys),
missing_indices);
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
const void* keys, void* values,
uint32_t* n_missing, void* missing_keys,
uint32_t* missing_indices) {
OF_CUDA_CHECK(
hipMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
constexpr uint32_t block_size = 128;
uint32_t grid_size = (n_keys + block_size - 1) / block_size;
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
EncodeLookupKernel<Key, Elem, Index, block_size>
<<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
encoding_buffer_, static_cast<Elem*>(values), n_missing, static_cast<Key*>(missing_keys),
missing_indices, encoder_.TableCapacity(), encoder_.table_keys(),
encoder_.table_indices());
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
const void* keys, void* values, uint8_t* mask) {
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
constexpr uint32_t block_size = 128;
uint32_t grid_size = (n_keys + block_size - 1) / block_size;
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
EncodeLookupMaskKernel<Key, Elem, Index, block_size, pack_size>
<<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
encoding_buffer_, static_cast<Elem*>(values), mask, encoder_.TableCapacity(),
encoder_.table_keys(), encoder_.table_indices());
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Put(ep::Stream* stream, uint32_t n_keys,
const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys,
void* evicted_values) {
OF_CUDA_CHECK(
hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((UpdateKernel<Elem, Index, pack_size>), stream, values_elem_cnt / pack_size,
num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_,
static_cast<const Elem*>(values));
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::FusedHalfUpdatePut(
ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update,
const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) {
if (!std::is_same<Elem, float>::value) { UNIMPLEMENTED(); }
OF_CUDA_CHECK(
hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((FusedHalfUpdateKernel<Elem, Index, pack_size>), stream,
values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt,
encoding_buffer_, static_cast<const Elem*>(values),
static_cast<const half*>(update), lr, scale);
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Dump(ep::Stream* stream, uint64_t start_key_index,
uint64_t end_key_index, uint32_t* n_dumped,
void* keys, void* values) {
encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
encoding_buffer_);
RUN_CUDA_KERNEL((DumpValueKernel<Key, Elem, Index>), stream,
num_elem_per_value_ * (end_key_index - start_key_index), num_elem_per_value_,
n_dumped, encoding_buffer_, values_, static_cast<Elem*>(values));
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Clear() {
encoder_.Clear();
}
template<typename Key, typename Index>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
if (options.value_type == DataType::kFloat) {
const size_t value_elem_cnt = options.value_size / sizeof(float);
const size_t half_warp = 16;
if (value_elem_cnt % 4 == 0 && value_elem_cnt / 4 > half_warp) {
return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 4>(options));
} else if (value_elem_cnt % 2 == 0 && value_elem_cnt / 2 > half_warp) {
return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 2>(options));
} else {
return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 1>(options));
}
} else if (options.value_size % sizeof(ulonglong2) == 0) {
return std::unique_ptr<Cache>(new CacheImpl<Key, ulonglong2, Index, 1>(options));
} else if (options.value_size % sizeof(uint64_t) == 0) {
return std::unique_ptr<Cache>(new CacheImpl<Key, uint64_t, Index, 1>(options));
} else if (options.value_size % sizeof(uint32_t) == 0) {
return std::unique_ptr<Cache>(new CacheImpl<Key, uint32_t, Index, 1>(options));
} else if (options.value_size % sizeof(uint16_t) == 0) {
return std::unique_ptr<Cache>(new CacheImpl<Key, uint16_t, Index, 1>(options));
} else {
return std::unique_ptr<Cache>(new CacheImpl<Key, uint8_t, Index, 1>(options));
}
}
template<typename Index>
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
if (options.key_size == sizeof(Key32)) {
return DispatchValueType<Key32, Index>(options);
} else if (options.key_size == sizeof(Key64)) {
return DispatchValueType<Key64, Index>(options);
} else {
UNIMPLEMENTED();
return nullptr;
}
}
std::unique_ptr<Cache> DispatchIndexType(const CacheOptions& options) {
const int64_t table_capacity = static_cast<double>(options.capacity) / options.load_factor;
if (table_capacity >= (1ULL << 31ULL)) {
return DispatchKeyType<uint64_t>(options);
} else {
return DispatchKeyType<uint32_t>(options);
}
}
} // namespace
std::unique_ptr<Cache> NewFullCache(const CacheOptions& options) {
return DispatchIndexType(options);
}
} // namespace embedding
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/full_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include "oneflow/core/hip/atomic.hip.h"
namespace oneflow {
namespace embedding {
using Key32 = unsigned int;
using Key64 = unsigned long long int;
using Key128 = ulonglong2;
namespace {
template<typename Key, typename Index>
__device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Index* table_size,
Key key, Index* out) {
Key key_hi = (key | 0x1);
Key key_lo = (key & 0x1);
Index index_plus_one = 0;
Key old_entry_key = cuda::atomic::CAS(entry_key, static_cast<Key>(0), key_hi);
while (index_plus_one == 0) {
if (old_entry_key == static_cast<Key>(0)) {
Index index = cuda::atomic::Add(table_size, static_cast<Index>(1));
index_plus_one = index + 1;
*entry_index = ((index_plus_one << 1U) | key_lo);
*out = index_plus_one;
return true;
} else if (old_entry_key == key_hi) {
const Index entry_index_val = *entry_index;
if (entry_index_val == 0) {
// do nothing
} else if ((entry_index_val & 0x1) == key_lo) {
*out = (entry_index_val >> 1U);
return true;
} else {
return false;
}
} else {
return false;
}
}
return false;
}
template<typename Key, typename Index>
__device__ bool GetOrInsertOne(const size_t capacity, Key* table_keys, Index* table_indices,
Index* table_size, Key key, size_t hash, Index* out) {
const size_t start_idx = hash % capacity;
for (size_t count = 0; count < capacity; ++count) {
const size_t idx = (start_idx + count) % capacity;
Key* entry_key = table_keys + idx;
Index* entry_index = table_indices + idx;
if (TryGetOrInsert<Key, Index>(entry_key, entry_index, table_size, key, out)) { return true; }
}
return false;
}
template<typename Key, typename Index>
__device__ bool GetOne(const size_t capacity, Key* table_keys, Index* table_indices, Key key,
size_t hash, Index* out) {
const size_t start_idx = hash % capacity;
for (size_t count = 0; count < capacity; ++count) {
const size_t idx = (start_idx + count) % capacity;
Key entry_key = table_keys[idx];
Key entry_index = table_indices[idx];
Key key_hi = (key | 0x1);
Key key_lo = (key & 0x1);
if (entry_key == 0) { break; }
if (entry_key == key_hi) {
if ((entry_index & 0x1) == key_lo) {
*out = (entry_index >> 1U);
return true;
}
}
}
*out = 0;
return false;
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
Index* table_size, uint32_t num_keys, const Key* keys,
Index* context) {
CUDA_1D_KERNEL_LOOP(i, num_keys) {
Key key = keys[i];
uint64_t hash = FullCacheHash()(key);
bool success = GetOrInsertOne<Key, Index>(capacity, table_keys, table_indices, table_size, key,
hash, context + i);
assert(success);
}
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeLookupKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
uint32_t num_keys, const Key* keys, Index* context) {
CUDA_1D_KERNEL_LOOP(i, num_keys) {
Key key = keys[i];
uint64_t hash = FullCacheHash()(key);
GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, context + i);
}
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeDumpKernel(const Key* table_keys, const Index* table_indices,
uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, Key* keys, Index* context) {
CUDA_1D_KERNEL_LOOP(i, (end_key_index - start_key_index)) {
Key entry_key = table_keys[i + start_key_index];
Index entry_index = table_indices[i + start_key_index];
if (entry_index != 0) {
uint32_t index = cuda::atomic::Add(n_dumped, static_cast<uint32_t>(1));
keys[index] = ((entry_key ^ 0x1) | (entry_index & 0x1));
context[index] = (entry_index >> 1U);
}
}
}
template<typename Key, typename Elem, typename Index, bool return_value>
__global__ void LookupKernel(uint32_t value_length, const Elem* cache_values,
uint32_t values_elem_cnt, const Key* keys, const Index* context,
Elem* values, uint32_t* n_missing, Key* missing_keys,
uint32_t* missing_indices) {
CUDA_1D_KERNEL_LOOP(i, values_elem_cnt) {
const uint64_t key_id = i / value_length;
const uint64_t ctx = context[key_id];
const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * value_length;
if (ctx == 0) {
const Key missing_key = keys[key_id];
if (col_id == 0) {
const uint32_t old_n_missing = cuda::atomic::Add(n_missing, static_cast<uint32_t>(1));
missing_keys[old_n_missing] = missing_key;
missing_indices[old_n_missing] = key_id;
}
continue;
}
if (return_value) { values[i] = cache_values[row_id * value_length + col_id]; }
}
}
template<typename Key, typename Elem, typename Index, uint32_t block_size>
__global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_values,
uint32_t values_elem_cnt, const Key* keys, const Index* context,
Elem* values, uint32_t* n_missing, Key* missing_keys,
uint32_t* missing_indices, const size_t capacity,
Key* table_keys, Index* table_indices) {
constexpr uint32_t warp_size = 32;
constexpr uint32_t n_warp_per_block = block_size / warp_size;
const uint32_t warp_id = threadIdx.x / warp_size;
const uint32_t lane_id = threadIdx.x % warp_size;
const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
const uint32_t n_keys = values_elem_cnt / value_length;
__shared__ Key batch_keys[n_warp_per_block][warp_size];
__shared__ Index batch_row_ids[n_warp_per_block][warp_size];
__shared__ Key batch_missing_keys[n_warp_per_block][warp_size];
__shared__ uint32_t batch_missing_indices[n_warp_per_block][warp_size];
__shared__ uint32_t batch_n_missing[n_warp_per_block];
for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
batch_start += global_n_warp * warp_size) {
const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
if (lane_id == 0) { batch_n_missing[warp_id] = 0; }
__syncthreads();
const uint32_t key_offset = batch_start + lane_id;
if (key_offset < n_keys) {
const Key key = keys[batch_start + lane_id];
const uint64_t hash = FullCacheHash()(key);
Index row;
GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
batch_row_ids[warp_id][lane_id] = row;
if (row == 0) {
const uint32_t batch_missing_idx = atomicAdd(batch_n_missing + warp_id, 1);
batch_missing_keys[warp_id][batch_missing_idx] = key;
batch_missing_indices[warp_id][batch_missing_idx] = key_offset;
}
}
__syncthreads();
const uint32_t batch_n_missing_t = batch_n_missing[warp_id];
if (lane_id == 0) {
const uint32_t old_n_missing =
cuda::atomic::Add(n_missing, static_cast<uint32_t>(batch_n_missing_t));
batch_n_missing[warp_id] = old_n_missing;
}
__syncthreads();
if (lane_id < batch_n_missing_t) {
missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id];
missing_indices[batch_n_missing[warp_id] + lane_id] = batch_missing_indices[warp_id][lane_id];
}
for (int i = 0; i < batch_n_key; ++i) {
const Key key = batch_keys[warp_id][i];
const Index row = batch_row_ids[warp_id][i];
if (row == 0) { continue; }
for (int col = lane_id; col < value_length; col += warp_size) {
values[(batch_start + i) * value_length + col] =
cache_values[(row - 1) * value_length + col];
}
}
__syncthreads();
}
}
template<typename T, size_t pack_size>
struct alignas(sizeof(T) * pack_size) Pack {
T elem[pack_size];
};
template<typename Key, typename Elem, typename Index, uint32_t block_size, uint32_t pack_size>
__global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __restrict__ cache_values,
uint32_t values_elem_cnt, const Key* __restrict__ keys,
const Index* __restrict__ context, Elem* __restrict__ values,
uint8_t* __restrict__ mask, const size_t capacity,
Key* __restrict__ table_keys,
Index* __restrict__ table_indices) {
const uint32_t packed_cols = value_length / pack_size;
auto* packed_values = reinterpret_cast<Pack<Elem, pack_size>*>(values);
const auto* packed_cache_values = reinterpret_cast<const Pack<Elem, pack_size>*>(cache_values);
constexpr uint32_t warp_size = 32;
constexpr uint32_t n_warp_per_block = block_size / warp_size;
const uint32_t warp_id = threadIdx.x / warp_size;
const uint32_t lane_id = threadIdx.x % warp_size;
const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
const uint32_t n_keys = values_elem_cnt / value_length;
__shared__ Key batch_keys[n_warp_per_block][warp_size];
__shared__ Index batch_row_ids[n_warp_per_block][warp_size];
for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
batch_start += global_n_warp * warp_size) {
const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
const uint32_t key_offset = batch_start + lane_id;
if (key_offset < n_keys) {
const Key key = keys[batch_start + lane_id];
const uint64_t hash = FullCacheHash()(key);
Index row;
GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
batch_row_ids[warp_id][lane_id] = row;
mask[key_offset] = row > 0;
}
__syncthreads();
for (int i = 0; i < batch_n_key; ++i) {
const Key key = batch_keys[warp_id][i];
const Index row = batch_row_ids[warp_id][i];
if (row == 0) { continue; }
#pragma unroll 4
for (int col = lane_id; col < packed_cols; col += warp_size) {
packed_values[(batch_start + i) * packed_cols + col] =
packed_cache_values[(row - 1) * packed_cols + col];
}
}
__syncthreads();
}
}
template<typename Elem, typename Index, size_t pack_size>
__global__ void UpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
const Index* context, const Elem* values) {
const int packed_values_elem_cnt = values_elem_cnt / pack_size;
const uint32_t packed_elem_cnt = value_length / pack_size;
auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {
const uint64_t key_id = i / packed_elem_cnt;
const uint64_t ctx = context[key_id];
if (ctx == 0) { continue; }
const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * packed_elem_cnt;
packed_cache_values[row_id * packed_elem_cnt + col_id] = packed_values[i];
}
}
template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* __restrict__ cache_values,
uint32_t values_elem_cnt, const Index* __restrict__ context,
const Elem* __restrict__ values, const half* __restrict__ update,
const float* __restrict__ lr, float scale) {
const int packed_values_elem_cnt = values_elem_cnt / pack_size;
const uint32_t packed_elem_cnt = value_length / pack_size;
auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
auto* packed_update = reinterpret_cast<const Pack<half, pack_size>*>(update);
const float alpha = -*lr * scale;
CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {
const uint64_t key_id = i / packed_elem_cnt;
const uint64_t ctx = context[key_id];
if (ctx == 0) { continue; }
const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * packed_elem_cnt;
Pack<Elem, pack_size> m = packed_values[i];
Pack<half, pack_size> u = packed_update[i];
for (size_t j = 0; j < pack_size; ++j) { m.elem[j] += static_cast<Elem>(u.elem[j]) * alpha; }
packed_cache_values[row_id * packed_elem_cnt + col_id] = m;
}
}
template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
const Index* context, const Elem* values, const half* update, const float* lr,
float scale) {
asm volatile("s_trap 0;");
}
template<typename Key, typename Elem, typename Index>
__global__ void DumpValueKernel(uint32_t value_length, const uint32_t* n_dumped,
const Index* context, const Elem* cache_values, Elem* values) {
CUDA_1D_KERNEL_LOOP(i, *n_dumped * value_length) {
const uint64_t key_id = i / value_length;
const uint64_t ctx = context[key_id];
const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * value_length;
values[i] = cache_values[row_id * value_length + col_id];
}
}
template<typename Key, typename Index>
class OrdinalEncoder {
public:
OF_DISALLOW_COPY_AND_MOVE(OrdinalEncoder);
explicit OrdinalEncoder(uint64_t capacity, float load_factor)
: capacity_(capacity), table_capacity_(capacity / load_factor) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
OF_CUDA_CHECK(hipMalloc(&table_size_, sizeof(Index)));
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&table_size_host_), sizeof(Index)));
OF_CUDA_CHECK(hipMalloc(&table_keys_, table_capacity_ * sizeof(Key)));
OF_CUDA_CHECK(hipMalloc(&table_indices_, table_capacity_ * sizeof(Index)));
Clear();
}
~OrdinalEncoder() {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipFree(table_size_));
OF_CUDA_CHECK(hipHostFree(table_size_host_));
OF_CUDA_CHECK(hipFree(table_keys_));
OF_CUDA_CHECK(hipFree(table_indices_));
}
template<bool insert>
void Encode(ep::Stream* stream, uint32_t num_keys, const Key* keys, Index* context) {
if (insert) {
RUN_CUDA_KERNEL((OrdinalEncodeKernel<Key, Index>), stream, num_keys, table_capacity_,
table_keys_, table_indices_, table_size_, num_keys, keys, context);
OF_CUDA_CHECK(hipMemcpyAsync(table_size_host_, table_size_, sizeof(Index), hipMemcpyDefault,
stream->As<ep::CudaStream>()->cuda_stream()));
CHECK_JUST(stream->Sync());
CHECK_LT(*table_size_host_, capacity_)
<< "The number of key is larger than cache size, please enlarge cache_memory_budget. ";
} else {
RUN_CUDA_KERNEL((OrdinalEncodeLookupKernel<Key, Index>), stream, num_keys, table_capacity_,
table_keys_, table_indices_, num_keys, keys, context);
}
}
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, Key* keys, Index* context) {
OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index>), stream, end_key_index - start_key_index,
table_keys_, table_indices_, start_key_index, end_key_index, n_dumped, keys,
context);
}
void Clear() {
OF_CUDA_CHECK(hipMemset(table_size_, 0, sizeof(Index)));
OF_CUDA_CHECK(hipMemset(table_keys_, 0, table_capacity_ * sizeof(Key)));
OF_CUDA_CHECK(hipMemset(table_indices_, 0, table_capacity_ * sizeof(Index)));
}
uint64_t TableCapacity() const { return table_capacity_; }
Key* table_keys() const { return table_keys_; }
Index* table_indices() const { return table_indices_; }
private:
int device_index_{};
Key* table_keys_;
Index* table_indices_;
uint64_t capacity_;
uint64_t table_capacity_;
Index* table_size_{};
Index* table_size_host_{};
};
template<typename Key, typename Elem, typename Index, size_t pack_size>
class CacheImpl : public Cache {
public:
OF_DISALLOW_COPY_AND_MOVE(CacheImpl);
explicit CacheImpl(const CacheOptions& options)
: encoder_(options.capacity, options.load_factor),
device_index_(-1),
options_(options),
max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
const uint64_t values_size = options.capacity * options.value_size;
if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(hipMalloc(&values_, values_size));
} else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&values_), values_size));
} else {
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&values_),
values_size));
}
} else {
UNIMPLEMENTED();
}
num_elem_per_value_ = options_.value_size / sizeof(Elem);
}
~CacheImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (options_.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(hipFree(values_));
} else if (options_.value_memory_kind == CacheOptions::MemoryKind::kHost) {
OF_CUDA_CHECK(hipHostFree(values_));
} else {
UNIMPLEMENTED();
}
if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
}
uint64_t Capacity() const override { return options_.capacity; }
uint64_t DumpCapacity() const override { return encoder_.TableCapacity(); }
uint32_t KeySize() const override { return options_.key_size; }
uint32_t ValueSize() const override { return options_.value_size; }
DataType ValueType() const override { return options_.value_type; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
OF_CUDA_CHECK(hipMalloc(&encoding_buffer_, query_length * sizeof(uint64_t)));
max_query_length_ = query_length;
}
CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kFull; }
void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override;
void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override;
void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,
uint8_t* mask) override;
void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override;
void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
const void* update, const float* lr, float scale, uint32_t* n_evicted,
void* evicted_keys, void* evicted_values) override;
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, void* keys, void* values) override;
void Clear() override;
private:
OrdinalEncoder<Key, Index> encoder_;
int device_index_;
uint32_t num_elem_per_value_{};
Elem* values_;
Index* encoding_buffer_{};
CacheOptions options_;
uint32_t max_query_length_;
};
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Test(ep::Stream* stream, uint32_t n_keys,
const void* keys, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) {
OF_CUDA_CHECK(
hipMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<false>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((LookupKernel<Key, Elem, Index, false>), stream, values_elem_cnt,
num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
encoding_buffer_, nullptr, n_missing, static_cast<Key*>(missing_keys),
missing_indices);
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
const void* keys, void* values,
uint32_t* n_missing, void* missing_keys,
uint32_t* missing_indices) {
OF_CUDA_CHECK(
hipMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
constexpr uint32_t block_size = 128;
uint32_t grid_size = (n_keys + block_size - 1) / block_size;
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
EncodeLookupKernel<Key, Elem, Index, block_size>
<<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
encoding_buffer_, static_cast<Elem*>(values), n_missing, static_cast<Key*>(missing_keys),
missing_indices, encoder_.TableCapacity(), encoder_.table_keys(),
encoder_.table_indices());
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
const void* keys, void* values, uint8_t* mask) {
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
constexpr uint32_t block_size = 128;
uint32_t grid_size = (n_keys + block_size - 1) / block_size;
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
EncodeLookupMaskKernel<Key, Elem, Index, block_size, pack_size>
<<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
encoding_buffer_, static_cast<Elem*>(values), mask, encoder_.TableCapacity(),
encoder_.table_keys(), encoder_.table_indices());
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Put(ep::Stream* stream, uint32_t n_keys,
const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys,
void* evicted_values) {
OF_CUDA_CHECK(
hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((UpdateKernel<Elem, Index, pack_size>), stream, values_elem_cnt / pack_size,
num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_,
static_cast<const Elem*>(values));
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::FusedHalfUpdatePut(
ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update,
const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) {
if (!std::is_same<Elem, float>::value) { UNIMPLEMENTED(); }
OF_CUDA_CHECK(
hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((FusedHalfUpdateKernel<Elem, Index, pack_size>), stream,
values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt,
encoding_buffer_, static_cast<const Elem*>(values),
static_cast<const half*>(update), lr, scale);
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Dump(ep::Stream* stream, uint64_t start_key_index,
uint64_t end_key_index, uint32_t* n_dumped,
void* keys, void* values) {
encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
encoding_buffer_);
RUN_CUDA_KERNEL((DumpValueKernel<Key, Elem, Index>), stream,
num_elem_per_value_ * (end_key_index - start_key_index), num_elem_per_value_,
n_dumped, encoding_buffer_, values_, static_cast<Elem*>(values));
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Clear() {
encoder_.Clear();
}
template<typename Key, typename Index>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
if (options.value_type == DataType::kFloat) {
const size_t value_elem_cnt = options.value_size / sizeof(float);
const size_t half_warp = 16;
if (value_elem_cnt % 4 == 0 && value_elem_cnt / 4 > half_warp) {
return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 4>(options));
} else if (value_elem_cnt % 2 == 0 && value_elem_cnt / 2 > half_warp) {
return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 2>(options));
} else {
return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 1>(options));
}
} else if (options.value_size % sizeof(ulonglong2) == 0) {
return std::unique_ptr<Cache>(new CacheImpl<Key, ulonglong2, Index, 1>(options));
} else if (options.value_size % sizeof(uint64_t) == 0) {
return std::unique_ptr<Cache>(new CacheImpl<Key, uint64_t, Index, 1>(options));
} else if (options.value_size % sizeof(uint32_t) == 0) {
return std::unique_ptr<Cache>(new CacheImpl<Key, uint32_t, Index, 1>(options));
} else if (options.value_size % sizeof(uint16_t) == 0) {
return std::unique_ptr<Cache>(new CacheImpl<Key, uint16_t, Index, 1>(options));
} else {
return std::unique_ptr<Cache>(new CacheImpl<Key, uint8_t, Index, 1>(options));
}
}
template<typename Index>
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
if (options.key_size == sizeof(Key32)) {
return DispatchValueType<Key32, Index>(options);
} else if (options.key_size == sizeof(Key64)) {
return DispatchValueType<Key64, Index>(options);
} else {
UNIMPLEMENTED();
return nullptr;
}
}
std::unique_ptr<Cache> DispatchIndexType(const CacheOptions& options) {
const int64_t table_capacity = static_cast<double>(options.capacity) / options.load_factor;
if (table_capacity >= (1ULL << 31ULL)) {
return DispatchKeyType<uint64_t>(options);
} else {
return DispatchKeyType<uint32_t>(options);
}
}
} // namespace
std::unique_ptr<Cache> NewFullCache(const CacheOptions& options) {
return DispatchIndexType(options);
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#define ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#include <stdint.h>
#include "oneflow/core/common/data_type.h"
namespace oneflow {
namespace embedding {
namespace {
// From https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h
static const uint64_t PRIME64_1 =
0x9E3779B185EBCA87ULL; // 0b1001111000110111011110011011000110000101111010111100101010000111
static const uint64_t PRIME64_2 =
0xC2B2AE3D27D4EB4FULL; // 0b1100001010110010101011100011110100100111110101001110101101001111
static const uint64_t PRIME64_3 =
0x165667B19E3779F9ULL; // 0b0001011001010110011001111011000110011110001101110111100111111001
static const uint64_t PRIME64_4 =
0x85EBCA77C2B2AE63ULL; // 0b1000010111101011110010100111011111000010101100101010111001100011
static const uint64_t PRIME64_5 =
0x27D4EB2F165667C5ULL; // 0b0010011111010100111010110010111100010110010101100110011111000101
#define XXH_rotl64(x, r) (((x) << (r)) | ((x) >> (64 - (r))))
OF_DEVICE_FUNC uint64_t XXH64_round(uint64_t acc, uint64_t input) {
acc += input * PRIME64_2;
acc = XXH_rotl64(acc, 31);
acc *= PRIME64_1;
return acc;
}
OF_DEVICE_FUNC uint64_t xxh64_uint64(uint64_t v, uint64_t seed) {
uint64_t acc = seed + PRIME64_5;
acc += sizeof(uint64_t);
acc = acc ^ XXH64_round(0, v);
acc = XXH_rotl64(acc, 27) * PRIME64_1;
acc = acc + PRIME64_4;
acc ^= (acc >> 33);
acc = acc * PRIME64_2;
acc = acc ^ (acc >> 29);
acc = acc * PRIME64_3;
acc = acc ^ (acc >> 32);
return acc;
}
static const size_t kShardingHashSeed = 1;
static const size_t kLocalUniqueHashSeed = 2;
static const size_t kGlobalUniqueHashSeed = 3;
static const size_t kFullCacheHashSeed = 4;
static const size_t kLruCacheHashSeed = 5;
} // namespace
struct ShardingHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kShardingHashSeed); }
OF_DEVICE_FUNC size_t operator()(uint32_t v) { return xxh64_uint64(v, kShardingHashSeed); }
OF_DEVICE_FUNC size_t operator()(int32_t v) {
return xxh64_uint64(static_cast<uint32_t>(v), kShardingHashSeed);
}
OF_DEVICE_FUNC size_t operator()(int64_t v) {
return xxh64_uint64(static_cast<uint64_t>(v), kShardingHashSeed);
}
};
struct LocalUniqueHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLocalUniqueHashSeed); }
};
struct GlobalUniqueHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kGlobalUniqueHashSeed); }
};
struct FullCacheHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kFullCacheHashSeed); }
};
struct LruCacheHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLruCacheHashSeed); }
};
} // namespace embedding
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#define ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#include <stdint.h>
#include "oneflow/core/common/data_type.h"
namespace oneflow {
namespace embedding {
namespace {
// From https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h
static const uint64_t PRIME64_1 =
0x9E3779B185EBCA87ULL; // 0b1001111000110111011110011011000110000101111010111100101010000111
static const uint64_t PRIME64_2 =
0xC2B2AE3D27D4EB4FULL; // 0b1100001010110010101011100011110100100111110101001110101101001111
static const uint64_t PRIME64_3 =
0x165667B19E3779F9ULL; // 0b0001011001010110011001111011000110011110001101110111100111111001
static const uint64_t PRIME64_4 =
0x85EBCA77C2B2AE63ULL; // 0b1000010111101011110010100111011111000010101100101010111001100011
static const uint64_t PRIME64_5 =
0x27D4EB2F165667C5ULL; // 0b0010011111010100111010110010111100010110010101100110011111000101
#define XXH_rotl64(x, r) (((x) << (r)) | ((x) >> (64 - (r))))
OF_DEVICE_FUNC uint64_t XXH64_round(uint64_t acc, uint64_t input) {
acc += input * PRIME64_2;
acc = XXH_rotl64(acc, 31);
acc *= PRIME64_1;
return acc;
}
OF_DEVICE_FUNC uint64_t xxh64_uint64(uint64_t v, uint64_t seed) {
uint64_t acc = seed + PRIME64_5;
acc += sizeof(uint64_t);
acc = acc ^ XXH64_round(0, v);
acc = XXH_rotl64(acc, 27) * PRIME64_1;
acc = acc + PRIME64_4;
acc ^= (acc >> 33);
acc = acc * PRIME64_2;
acc = acc ^ (acc >> 29);
acc = acc * PRIME64_3;
acc = acc ^ (acc >> 32);
return acc;
}
static const size_t kShardingHashSeed = 1;
static const size_t kLocalUniqueHashSeed = 2;
static const size_t kGlobalUniqueHashSeed = 3;
static const size_t kFullCacheHashSeed = 4;
static const size_t kLruCacheHashSeed = 5;
} // namespace
struct ShardingHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kShardingHashSeed); }
OF_DEVICE_FUNC size_t operator()(uint32_t v) { return xxh64_uint64(v, kShardingHashSeed); }
OF_DEVICE_FUNC size_t operator()(int32_t v) {
return xxh64_uint64(static_cast<uint32_t>(v), kShardingHashSeed);
}
OF_DEVICE_FUNC size_t operator()(int64_t v) {
return xxh64_uint64(static_cast<uint64_t>(v), kShardingHashSeed);
}
};
struct LocalUniqueHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLocalUniqueHashSeed); }
};
struct GlobalUniqueHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kGlobalUniqueHashSeed); }
};
struct FullCacheHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kFullCacheHashSeed); }
};
struct LruCacheHash {
OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLruCacheHashSeed); }
};
} // namespace embedding
} // namespace oneflow
#endif // ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Inspired by https://github.com/NVIDIA-Merlin/HugeCTR/blob/master/gpu_cache/src/nv_gpu_cache.cu
#include "oneflow/core/embedding/lru_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include <new>
#include <hip/hip_runtime.h>
namespace oneflow {
namespace embedding {
namespace {
constexpr int kWarpSize = 64;
constexpr int kNumWarpPerBlock = 2;
constexpr int kBlockSize = kNumWarpPerBlock * kWarpSize;
constexpr unsigned long long int kFullMask = 0xFFFFFFFFFFFFFFFFU;
ep::CudaLaunchConfig GetLaunchConfig(uint32_t n_keys) {
return ep::CudaLaunchConfig((n_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock,
kWarpSize * kNumWarpPerBlock, 0);
}
struct ThreadContext {
__device__ ThreadContext() {
const uint32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
global_warp_id = global_thread_id / kWarpSize;
warp_id_in_block = global_warp_id % kNumWarpPerBlock; // NOLINT
num_warps = gridDim.x * kNumWarpPerBlock; // NOLINT
lane_id = global_thread_id % kWarpSize;
}
uint32_t global_warp_id;
uint32_t warp_id_in_block;
uint32_t num_warps;
uint32_t lane_id;
};
class WarpMutexAtomicImpl {
public:
OF_DISALLOW_COPY_AND_MOVE(WarpMutexAtomicImpl);
__device__ WarpMutexAtomicImpl() : flag_(0) {}
__device__ ~WarpMutexAtomicImpl() = default;
__device__ void Lock(const ThreadContext& thread_ctx) {
if (thread_ctx.lane_id == 0) {
while (atomicCAS(&flag_, 0, 1) != 0)
;
}
__threadfence();
__syncthreads();
}
__device__ void Unlock(const ThreadContext& thread_ctx) {
__syncthreads();
__threadfence();
if (thread_ctx.lane_id == 0) { atomicExch(&flag_, 0); }
}
private:
int32_t flag_;
};
template<typename Key, typename Elem>
struct LruCacheContext {
Key* keys;
Elem* lines;
uint8_t* ages;
void* mutex;
uint64_t n_set;
uint32_t line_size;
CacheOptions::MemoryKind value_memory_kind;
};
__global__ void InitCacheSetMutex(uint32_t n_set, void* mutex) {
using WarpMutex = WarpMutexAtomicImpl;
const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n_set) { new (reinterpret_cast<WarpMutex*>(mutex) + idx) WarpMutex; }
}
template<typename Key, typename Elem>
void ClearLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
OF_CUDA_CHECK(hipMemset(ctx->keys, 0, ctx->n_set * kWarpSize * sizeof(Key)));
OF_CUDA_CHECK(hipMemset(ctx->ages, 0, ctx->n_set * kWarpSize * sizeof(uint8_t)));
InitCacheSetMutex<<<(ctx->n_set - 1 + 256) / 256, 256>>>(ctx->n_set, ctx->mutex);
}
template<typename Key, typename Elem>
void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>* ctx) {
const size_t keys_size_per_set = kWarpSize * sizeof(Key);
const uint32_t line_size = options.value_size / sizeof(Elem);
const size_t lines_size_per_set = kWarpSize * line_size * sizeof(Elem);
const size_t ages_size_per_set = kWarpSize * sizeof(uint8_t);
int device = 0;
OF_CUDA_CHECK(hipGetDevice(&device));
int major = 0;
OF_CUDA_CHECK(hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, device));
size_t mutex_size_per_set = 0;
mutex_size_per_set = sizeof(WarpMutexAtomicImpl);
const size_t n_set = (options.capacity - 1 + kWarpSize) / kWarpSize;
CHECK_GT(n_set, 0);
ctx->n_set = n_set;
ctx->line_size = line_size;
const size_t keys_size = n_set * keys_size_per_set;
OF_CUDA_CHECK(hipMalloc(&(ctx->keys), keys_size));
const size_t lines_size = n_set * lines_size_per_set;
if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(hipMalloc(&(ctx->lines), lines_size));
} else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&(ctx->lines)), lines_size));
} else {
OF_CUDA_CHECK(
NumaAwareCudaMallocHost(device, reinterpret_cast<void**>(&ctx->lines), lines_size));
}
} else {
UNIMPLEMENTED();
}
ctx->value_memory_kind = options.value_memory_kind;
const size_t ages_size = n_set * ages_size_per_set;
OF_CUDA_CHECK(hipMalloc(&(ctx->ages), ages_size));
const size_t mutex_size = n_set * mutex_size_per_set;
OF_CUDA_CHECK(hipMalloc(&(ctx->mutex), mutex_size));
ClearLruCacheContext(ctx);
}
template<typename Key, typename Elem>
void DestroyLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
OF_CUDA_CHECK(hipFree(ctx->keys));
if (ctx->value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(hipFree(ctx->lines));
} else if (ctx->value_memory_kind == CacheOptions::MemoryKind::kHost) {
OF_CUDA_CHECK(hipHostFree(ctx->lines));
} else {
UNIMPLEMENTED();
}
OF_CUDA_CHECK(hipFree(ctx->ages));
OF_CUDA_CHECK(hipFree(ctx->mutex));
}
template<typename Key, typename Elem>
struct SetContext {
using WarpMutex = WarpMutexAtomicImpl;
__device__ SetContext(const LruCacheContext<Key, Elem>& ctx, uint32_t set_id)
: keys(ctx.keys + set_id * kWarpSize),
mutex(reinterpret_cast<WarpMutex*>(ctx.mutex) + set_id),
ages(ctx.ages + set_id * kWarpSize),
lines(ctx.lines + set_id * kWarpSize * ctx.line_size) {}
__device__ int Lookup(const ThreadContext& thread_ctx, Key key) {
const Key lane_key = keys[thread_ctx.lane_id];
const int lane_age = ages[thread_ctx.lane_id];
const bool lane_hit = (lane_key == key && lane_age != 0);
const unsigned long long int hit_mask = __ballot(lane_hit);
if (hit_mask != 0) {
return __ffs(static_cast<int>(hit_mask)) - 1;
} else {
return -1;
}
}
__device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
int way, Elem* line) {
const Elem* from_line = lines + way * cache_ctx.line_size;
for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
line[i] = from_line[i];
}
}
__device__ int InsertWithoutEvicting(const LruCacheContext<Key, Elem>& cache_ctx,
const ThreadContext& thread_ctx, Key key) {
int insert_way = -1;
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
if (hit_mask != 0) {
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
const int insert_way_age = __shfl(lane_age, insert_way);
if (lane_age > insert_way_age) {
lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) {
lane_age = kWarpSize;
}
__syncthreads();
}
if (insert_way == -1) {
const unsigned long long int valid_mask = __ballot(lane_age != 0);
if (valid_mask != kFullMask) {
insert_way = __popc(static_cast<int>(valid_mask));
if (lane_age > 0) {
lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) {
lane_age = kWarpSize;
keys[insert_way] = key;
}
__syncthreads();
}
}
if (insert_way != -1) { ages[thread_ctx.lane_id] = lane_age; }
return insert_way;
}
__device__ void Evict(const LruCacheContext<Key, Elem>& cache_ctx,
const ThreadContext& thread_ctx, Key key, int* way, Key* evicted_key) {
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1;
*evicted_key = __shfl(lane_key, insert_way);
if (thread_ctx.lane_id == insert_way) {
keys[insert_way] = key;
lane_age = kWarpSize;
} else if (lane_age > 1) {
lane_age -= 1;
}
__syncthreads();
ages[thread_ctx.lane_id] = lane_age;
*way = insert_way;
}
__device__ void Write(const LruCacheContext<Key, Elem>& cache_ctx,
const ThreadContext& thread_ctx, int way, const Elem* line) {
Elem* to_line = lines + way * cache_ctx.line_size;
for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
to_line[i] = line[i];
}
}
__device__ void Lock(const ThreadContext& thread_ctx) { mutex->Lock(thread_ctx); }
__device__ void Unlock(const ThreadContext& thread_ctx) { mutex->Unlock(thread_ctx); }
Key* keys;
Elem* lines;
uint8_t* ages;
WarpMutex* mutex;
};
template<typename Key, typename Elem, bool test_only>
__global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys, const Key* keys,
Elem* values, uint32_t* n_missing_keys, Key* missing_keys,
uint32_t* missing_indices) {
ThreadContext thread_ctx{};
__shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
__shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
batch_offset += thread_ctx.num_warps * kWarpSize) {
const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
if (thread_ctx.lane_id < n_batch_keys) {
const Key key = keys[batch_offset + thread_ctx.lane_id];
const size_t hash = LruCacheHash()(key);
const uint32_t set_id = hash % cache_ctx.n_set;
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncthreads();
uint32_t n_warp_missing = 0;
Key warp_missing_key = 0;
uint32_t warp_missing_index = 0;
for (uint32_t i = 0; i < n_batch_keys; ++i) {
const uint32_t key_idx = batch_offset + i;
const Key key = block_keys[thread_ctx.warp_id_in_block][i];
const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
const int way = set_ctx.Lookup(thread_ctx, key);
if (way < 0) {
if (thread_ctx.lane_id == n_warp_missing) {
warp_missing_key = key;
warp_missing_index = key_idx;
}
__syncthreads();
n_warp_missing += 1;
} else if (!test_only) {
set_ctx.Read(cache_ctx, thread_ctx, way, values + key_idx * cache_ctx.line_size);
}
}
if (n_warp_missing > 0) {
uint32_t base_missing_idx = 0;
if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing_keys, n_warp_missing); }
__syncthreads();
base_missing_idx = __shfl(base_missing_idx, 0);
if (thread_ctx.lane_id < n_warp_missing) {
missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
}
__syncthreads();
}
__syncthreads();
}
}
template<typename Key, typename Elem>
__global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys,
const Key* keys, const Elem* values, uint32_t* n_missing,
Key* missing_keys, uint32_t* missing_indices) {
ThreadContext thread_ctx{};
__shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
__shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
batch_offset += thread_ctx.num_warps * kWarpSize) {
const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
if (thread_ctx.lane_id < n_batch_keys) {
const Key key = keys[batch_offset + thread_ctx.lane_id];
const size_t hash = LruCacheHash()(key);
const uint32_t set_id = hash % cache_ctx.n_set;
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncthreads();
uint32_t n_warp_missing = 0;
Key warp_missing_key = 0;
uint32_t warp_missing_index = 0;
for (uint32_t i = 0; i < n_batch_keys; ++i) {
const uint32_t key_idx = batch_offset + i;
const Key key = block_keys[thread_ctx.warp_id_in_block][i];
const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
set_ctx.Lock(thread_ctx);
Key evicted_key = 0;
const int insert_way = set_ctx.InsertWithoutEvicting(cache_ctx, thread_ctx, key);
if (insert_way >= 0) {
set_ctx.Write(cache_ctx, thread_ctx, insert_way, values + cache_ctx.line_size * key_idx);
} else {
if (thread_ctx.lane_id == n_warp_missing) {
warp_missing_key = key;
warp_missing_index = key_idx;
}
__syncthreads();
n_warp_missing += 1;
}
set_ctx.Unlock(thread_ctx);
}
if (n_warp_missing > 0) {
uint32_t base_missing_idx = 0;
if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing, n_warp_missing); }
__syncthreads();
base_missing_idx = __shfl(base_missing_idx, 0);
if (thread_ctx.lane_id < n_warp_missing) {
missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
}
__syncthreads();
}
}
}
template<typename Key, typename Elem>
__global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* keys,
const uint32_t* indices, const Elem* values, const uint32_t* n_evict,
Key* evicted_keys, Elem* evicted_values) {
ThreadContext thread_ctx{};
uint32_t num_evict = *n_evict;
__shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
__shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_evict;
batch_offset += thread_ctx.num_warps * kWarpSize) {
const uint32_t n_batch_keys = min(kWarpSize, num_evict - batch_offset);
if (thread_ctx.lane_id < n_batch_keys) {
const Key key = keys[batch_offset + thread_ctx.lane_id];
const size_t hash = LruCacheHash()(key);
const uint32_t set_id = hash % cache_ctx.n_set;
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncthreads();
for (uint32_t i = 0; i < n_batch_keys; ++i) {
const uint32_t key_idx = batch_offset + i;
const Key key = block_keys[thread_ctx.warp_id_in_block][i];
const uint32_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
set_ctx.Lock(thread_ctx);
int evicted_way = -1;
Key evicted_key = 0;
set_ctx.Evict(cache_ctx, thread_ctx, key, &evicted_way, &evicted_key);
if (thread_ctx.lane_id == 0) { evicted_keys[key_idx] = evicted_key; }
__syncthreads();
set_ctx.Read(cache_ctx, thread_ctx, evicted_way,
evicted_values + cache_ctx.line_size * key_idx);
set_ctx.Write(cache_ctx, thread_ctx, evicted_way,
values + cache_ctx.line_size * indices[key_idx]);
set_ctx.Unlock(thread_ctx);
}
}
}
template<typename Key, typename Elem>
__global__ void DumpKernel(LruCacheContext<Key, Elem> cache_ctx, size_t start_key_index,
size_t end_key_index, uint32_t* n_dumped, Key* keys, Elem* values) {
ThreadContext thread_ctx{};
__shared__ Key warp_keys[kNumWarpPerBlock][kWarpSize];
__shared__ uint8_t warp_ages[kNumWarpPerBlock][kWarpSize];
for (uint32_t warp_start_key_index = start_key_index + thread_ctx.global_warp_id * kWarpSize;
warp_start_key_index < end_key_index;
warp_start_key_index += thread_ctx.num_warps * kWarpSize) {
Key lane_key = 0;
uint8_t lane_age = 0;
if (warp_start_key_index + thread_ctx.lane_id < end_key_index) {
lane_key = cache_ctx.keys[warp_start_key_index + thread_ctx.lane_id];
lane_age = cache_ctx.ages[warp_start_key_index + thread_ctx.lane_id];
}
__syncthreads();
warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
if (key_count == 0) { continue; }
uint32_t offset = 0;
if (thread_ctx.lane_id == 0) { offset = atomicAdd(n_dumped, key_count); }
offset = __shfl(offset, 0);
__syncthreads();
for (uint32_t i = 0; i < kWarpSize; ++i) {
const Key key = warp_keys[thread_ctx.warp_id_in_block][i];
const Key age = warp_ages[thread_ctx.warp_id_in_block][i];
if (age == 0) { continue; }
if (thread_ctx.lane_id == 0) { keys[offset] = key; }
__syncthreads();
for (uint32_t j = thread_ctx.lane_id; j < cache_ctx.line_size; j += kWarpSize) {
values[offset * cache_ctx.line_size + j] =
cache_ctx.lines[(warp_start_key_index + i) * cache_ctx.line_size + j];
}
__syncthreads();
offset += 1;
}
}
}
template<typename Key, typename Elem>
class LruCache : public Cache {
public:
OF_DISALLOW_COPY_AND_MOVE(LruCache);
explicit LruCache(const CacheOptions& options)
: device_index_{},
max_query_length_(0),
query_indices_buffer_(nullptr),
query_keys_buffer_(nullptr),
value_type_(options.value_type) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
InitLruCacheContext(options, &ctx_);
}
~LruCache() override {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(query_indices_buffer_));
OF_CUDA_CHECK(hipFree(query_keys_buffer_));
}
DestroyLruCacheContext(&ctx_);
}
uint32_t KeySize() const override { return sizeof(Key); }
uint32_t ValueSize() const override { return sizeof(Elem) * ctx_.line_size; }
DataType ValueType() const override { return value_type_; }
uint64_t Capacity() const override { return ctx_.n_set * kWarpSize; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length < max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(query_indices_buffer_));
OF_CUDA_CHECK(hipFree(query_keys_buffer_));
}
OF_CUDA_CHECK(hipMalloc(&query_indices_buffer_, query_length * sizeof(uint32_t)));
OF_CUDA_CHECK(hipMalloc(&query_keys_buffer_, query_length * sizeof(Key)));
max_query_length_ = query_length;
}
CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kLRU; }
void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(GetKernel<Key, Elem, true>, GetLaunchConfig(n_keys), ctx_, n_keys,
static_cast<const Key*>(keys), nullptr, n_missing,
static_cast<Key*>(missing_keys), missing_indices);
}
void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(GetKernel<Key, Elem, false>, GetLaunchConfig(n_keys), ctx_, n_keys,
static_cast<const Key*>(keys), static_cast<Elem*>(values), n_missing,
static_cast<Key*>(missing_keys), missing_indices);
}
void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(PutWithoutEvictingKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
n_keys, static_cast<const Key*>(keys),
static_cast<const Elem*>(values), n_evicted, query_keys_buffer_,
query_indices_buffer_);
cuda_stream->LaunchKernel(EvictKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
query_keys_buffer_, query_indices_buffer_,
static_cast<const Elem*>(values), n_evicted,
static_cast<Key*>(evicted_keys), static_cast<Elem*>(evicted_values));
}
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, void* keys, void* values) override {
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
const uint64_t max_dump_keys = end_key_index - start_key_index;
cuda_stream->LaunchKernel(
DumpKernel<Key, Elem>,
ep::CudaLaunchConfig((max_dump_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock, kBlockSize,
0),
ctx_, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
static_cast<Elem*>(values));
}
void Clear() override { ClearLruCacheContext<Key, Elem>(&ctx_); }
private:
int device_index_;
uint32_t max_query_length_;
LruCacheContext<Key, Elem> ctx_;
uint32_t* query_indices_buffer_;
Key* query_keys_buffer_;
DataType value_type_;
};
template<typename Key>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
if (options.value_size % sizeof(ulonglong2) == 0) {
return std::unique_ptr<Cache>(new LruCache<Key, ulonglong2>(options));
} else if (options.value_size % sizeof(uint64_t) == 0) {
return std::unique_ptr<Cache>(new LruCache<Key, uint64_t>(options));
} else if (options.value_size % sizeof(uint32_t) == 0) {
return std::unique_ptr<Cache>(new LruCache<Key, uint32_t>(options));
} else if (options.value_size % sizeof(uint16_t) == 0) {
return std::unique_ptr<Cache>(new LruCache<Key, uint16_t>(options));
} else {
return std::unique_ptr<Cache>(new LruCache<Key, uint8_t>(options));
}
}
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
if (options.key_size == sizeof(uint32_t)) {
return DispatchValueType<uint32_t>(options);
} else if (options.key_size == sizeof(uint64_t)) {
return DispatchValueType<uint64_t>(options);
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace
std::unique_ptr<Cache> NewLruCache(const CacheOptions& options) { return DispatchKeyType(options); }
} // namespace embedding
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Inspired by https://github.com/NVIDIA-Merlin/HugeCTR/blob/master/gpu_cache/src/nv_gpu_cache.cu
#include "oneflow/core/embedding/lru_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include <new>
#include <hip/hip_runtime.h>
namespace oneflow {
namespace embedding {
namespace {
constexpr int kWarpSize = 64;
constexpr int kNumWarpPerBlock = 2;
constexpr int kBlockSize = kNumWarpPerBlock * kWarpSize;
constexpr unsigned long long int kFullMask = 0xFFFFFFFFFFFFFFFFU;
ep::CudaLaunchConfig GetLaunchConfig(uint32_t n_keys) {
return ep::CudaLaunchConfig((n_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock,
kWarpSize * kNumWarpPerBlock, 0);
}
struct ThreadContext {
__device__ ThreadContext() {
const uint32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
global_warp_id = global_thread_id / kWarpSize;
warp_id_in_block = global_warp_id % kNumWarpPerBlock; // NOLINT
num_warps = gridDim.x * kNumWarpPerBlock; // NOLINT
lane_id = global_thread_id % kWarpSize;
}
uint32_t global_warp_id;
uint32_t warp_id_in_block;
uint32_t num_warps;
uint32_t lane_id;
};
class WarpMutexAtomicImpl {
public:
OF_DISALLOW_COPY_AND_MOVE(WarpMutexAtomicImpl);
__device__ WarpMutexAtomicImpl() : flag_(0) {}
__device__ ~WarpMutexAtomicImpl() = default;
__device__ void Lock(const ThreadContext& thread_ctx) {
if (thread_ctx.lane_id == 0) {
while (atomicCAS(&flag_, 0, 1) != 0)
;
}
__threadfence();
__syncthreads();
}
__device__ void Unlock(const ThreadContext& thread_ctx) {
__syncthreads();
__threadfence();
if (thread_ctx.lane_id == 0) { atomicExch(&flag_, 0); }
}
private:
int32_t flag_;
};
template<typename Key, typename Elem>
struct LruCacheContext {
Key* keys;
Elem* lines;
uint8_t* ages;
void* mutex;
uint64_t n_set;
uint32_t line_size;
CacheOptions::MemoryKind value_memory_kind;
};
__global__ void InitCacheSetMutex(uint32_t n_set, void* mutex) {
using WarpMutex = WarpMutexAtomicImpl;
const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n_set) { new (reinterpret_cast<WarpMutex*>(mutex) + idx) WarpMutex; }
}
template<typename Key, typename Elem>
void ClearLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
OF_CUDA_CHECK(hipMemset(ctx->keys, 0, ctx->n_set * kWarpSize * sizeof(Key)));
OF_CUDA_CHECK(hipMemset(ctx->ages, 0, ctx->n_set * kWarpSize * sizeof(uint8_t)));
InitCacheSetMutex<<<(ctx->n_set - 1 + 256) / 256, 256>>>(ctx->n_set, ctx->mutex);
}
template<typename Key, typename Elem>
void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>* ctx) {
const size_t keys_size_per_set = kWarpSize * sizeof(Key);
const uint32_t line_size = options.value_size / sizeof(Elem);
const size_t lines_size_per_set = kWarpSize * line_size * sizeof(Elem);
const size_t ages_size_per_set = kWarpSize * sizeof(uint8_t);
int device = 0;
OF_CUDA_CHECK(hipGetDevice(&device));
int major = 0;
OF_CUDA_CHECK(hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, device));
size_t mutex_size_per_set = 0;
mutex_size_per_set = sizeof(WarpMutexAtomicImpl);
const size_t n_set = (options.capacity - 1 + kWarpSize) / kWarpSize;
CHECK_GT(n_set, 0);
ctx->n_set = n_set;
ctx->line_size = line_size;
const size_t keys_size = n_set * keys_size_per_set;
OF_CUDA_CHECK(hipMalloc(&(ctx->keys), keys_size));
const size_t lines_size = n_set * lines_size_per_set;
if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(hipMalloc(&(ctx->lines), lines_size));
} else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&(ctx->lines)), lines_size));
} else {
OF_CUDA_CHECK(
NumaAwareCudaMallocHost(device, reinterpret_cast<void**>(&ctx->lines), lines_size));
}
} else {
UNIMPLEMENTED();
}
ctx->value_memory_kind = options.value_memory_kind;
const size_t ages_size = n_set * ages_size_per_set;
OF_CUDA_CHECK(hipMalloc(&(ctx->ages), ages_size));
const size_t mutex_size = n_set * mutex_size_per_set;
OF_CUDA_CHECK(hipMalloc(&(ctx->mutex), mutex_size));
ClearLruCacheContext(ctx);
}
template<typename Key, typename Elem>
void DestroyLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
OF_CUDA_CHECK(hipFree(ctx->keys));
if (ctx->value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(hipFree(ctx->lines));
} else if (ctx->value_memory_kind == CacheOptions::MemoryKind::kHost) {
OF_CUDA_CHECK(hipHostFree(ctx->lines));
} else {
UNIMPLEMENTED();
}
OF_CUDA_CHECK(hipFree(ctx->ages));
OF_CUDA_CHECK(hipFree(ctx->mutex));
}
template<typename Key, typename Elem>
struct SetContext {
using WarpMutex = WarpMutexAtomicImpl;
__device__ SetContext(const LruCacheContext<Key, Elem>& ctx, uint32_t set_id)
: keys(ctx.keys + set_id * kWarpSize),
mutex(reinterpret_cast<WarpMutex*>(ctx.mutex) + set_id),
ages(ctx.ages + set_id * kWarpSize),
lines(ctx.lines + set_id * kWarpSize * ctx.line_size) {}
__device__ int Lookup(const ThreadContext& thread_ctx, Key key) {
const Key lane_key = keys[thread_ctx.lane_id];
const int lane_age = ages[thread_ctx.lane_id];
const bool lane_hit = (lane_key == key && lane_age != 0);
const unsigned long long int hit_mask = __ballot(lane_hit);
if (hit_mask != 0) {
return __ffs(static_cast<int>(hit_mask)) - 1;
} else {
return -1;
}
}
__device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
int way, Elem* line) {
const Elem* from_line = lines + way * cache_ctx.line_size;
for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
line[i] = from_line[i];
}
}
__device__ int InsertWithoutEvicting(const LruCacheContext<Key, Elem>& cache_ctx,
const ThreadContext& thread_ctx, Key key) {
int insert_way = -1;
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
if (hit_mask != 0) {
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
const int insert_way_age = __shfl(lane_age, insert_way);
if (lane_age > insert_way_age) {
lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) {
lane_age = kWarpSize;
}
__syncthreads();
}
if (insert_way == -1) {
const unsigned long long int valid_mask = __ballot(lane_age != 0);
if (valid_mask != kFullMask) {
insert_way = __popc(static_cast<int>(valid_mask));
if (lane_age > 0) {
lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) {
lane_age = kWarpSize;
keys[insert_way] = key;
}
__syncthreads();
}
}
if (insert_way != -1) { ages[thread_ctx.lane_id] = lane_age; }
return insert_way;
}
__device__ void Evict(const LruCacheContext<Key, Elem>& cache_ctx,
const ThreadContext& thread_ctx, Key key, int* way, Key* evicted_key) {
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1;
*evicted_key = __shfl(lane_key, insert_way);
if (thread_ctx.lane_id == insert_way) {
keys[insert_way] = key;
lane_age = kWarpSize;
} else if (lane_age > 1) {
lane_age -= 1;
}
__syncthreads();
ages[thread_ctx.lane_id] = lane_age;
*way = insert_way;
}
__device__ void Write(const LruCacheContext<Key, Elem>& cache_ctx,
const ThreadContext& thread_ctx, int way, const Elem* line) {
Elem* to_line = lines + way * cache_ctx.line_size;
for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
to_line[i] = line[i];
}
}
__device__ void Lock(const ThreadContext& thread_ctx) { mutex->Lock(thread_ctx); }
__device__ void Unlock(const ThreadContext& thread_ctx) { mutex->Unlock(thread_ctx); }
Key* keys;
Elem* lines;
uint8_t* ages;
WarpMutex* mutex;
};
template<typename Key, typename Elem, bool test_only>
__global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys, const Key* keys,
Elem* values, uint32_t* n_missing_keys, Key* missing_keys,
uint32_t* missing_indices) {
ThreadContext thread_ctx{};
__shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
__shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
batch_offset += thread_ctx.num_warps * kWarpSize) {
const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
if (thread_ctx.lane_id < n_batch_keys) {
const Key key = keys[batch_offset + thread_ctx.lane_id];
const size_t hash = LruCacheHash()(key);
const uint32_t set_id = hash % cache_ctx.n_set;
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncthreads();
uint32_t n_warp_missing = 0;
Key warp_missing_key = 0;
uint32_t warp_missing_index = 0;
for (uint32_t i = 0; i < n_batch_keys; ++i) {
const uint32_t key_idx = batch_offset + i;
const Key key = block_keys[thread_ctx.warp_id_in_block][i];
const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
const int way = set_ctx.Lookup(thread_ctx, key);
if (way < 0) {
if (thread_ctx.lane_id == n_warp_missing) {
warp_missing_key = key;
warp_missing_index = key_idx;
}
__syncthreads();
n_warp_missing += 1;
} else if (!test_only) {
set_ctx.Read(cache_ctx, thread_ctx, way, values + key_idx * cache_ctx.line_size);
}
}
if (n_warp_missing > 0) {
uint32_t base_missing_idx = 0;
if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing_keys, n_warp_missing); }
__syncthreads();
base_missing_idx = __shfl(base_missing_idx, 0);
if (thread_ctx.lane_id < n_warp_missing) {
missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
}
__syncthreads();
}
__syncthreads();
}
}
template<typename Key, typename Elem>
__global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys,
const Key* keys, const Elem* values, uint32_t* n_missing,
Key* missing_keys, uint32_t* missing_indices) {
ThreadContext thread_ctx{};
__shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
__shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
batch_offset += thread_ctx.num_warps * kWarpSize) {
const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
if (thread_ctx.lane_id < n_batch_keys) {
const Key key = keys[batch_offset + thread_ctx.lane_id];
const size_t hash = LruCacheHash()(key);
const uint32_t set_id = hash % cache_ctx.n_set;
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncthreads();
uint32_t n_warp_missing = 0;
Key warp_missing_key = 0;
uint32_t warp_missing_index = 0;
for (uint32_t i = 0; i < n_batch_keys; ++i) {
const uint32_t key_idx = batch_offset + i;
const Key key = block_keys[thread_ctx.warp_id_in_block][i];
const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
set_ctx.Lock(thread_ctx);
Key evicted_key = 0;
const int insert_way = set_ctx.InsertWithoutEvicting(cache_ctx, thread_ctx, key);
if (insert_way >= 0) {
set_ctx.Write(cache_ctx, thread_ctx, insert_way, values + cache_ctx.line_size * key_idx);
} else {
if (thread_ctx.lane_id == n_warp_missing) {
warp_missing_key = key;
warp_missing_index = key_idx;
}
__syncthreads();
n_warp_missing += 1;
}
set_ctx.Unlock(thread_ctx);
}
if (n_warp_missing > 0) {
uint32_t base_missing_idx = 0;
if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing, n_warp_missing); }
__syncthreads();
base_missing_idx = __shfl(base_missing_idx, 0);
if (thread_ctx.lane_id < n_warp_missing) {
missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
}
__syncthreads();
}
}
}
template<typename Key, typename Elem>
__global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* keys,
const uint32_t* indices, const Elem* values, const uint32_t* n_evict,
Key* evicted_keys, Elem* evicted_values) {
ThreadContext thread_ctx{};
uint32_t num_evict = *n_evict;
__shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
__shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_evict;
batch_offset += thread_ctx.num_warps * kWarpSize) {
const uint32_t n_batch_keys = min(kWarpSize, num_evict - batch_offset);
if (thread_ctx.lane_id < n_batch_keys) {
const Key key = keys[batch_offset + thread_ctx.lane_id];
const size_t hash = LruCacheHash()(key);
const uint32_t set_id = hash % cache_ctx.n_set;
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncthreads();
for (uint32_t i = 0; i < n_batch_keys; ++i) {
const uint32_t key_idx = batch_offset + i;
const Key key = block_keys[thread_ctx.warp_id_in_block][i];
const uint32_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
set_ctx.Lock(thread_ctx);
int evicted_way = -1;
Key evicted_key = 0;
set_ctx.Evict(cache_ctx, thread_ctx, key, &evicted_way, &evicted_key);
if (thread_ctx.lane_id == 0) { evicted_keys[key_idx] = evicted_key; }
__syncthreads();
set_ctx.Read(cache_ctx, thread_ctx, evicted_way,
evicted_values + cache_ctx.line_size * key_idx);
set_ctx.Write(cache_ctx, thread_ctx, evicted_way,
values + cache_ctx.line_size * indices[key_idx]);
set_ctx.Unlock(thread_ctx);
}
}
}
template<typename Key, typename Elem>
__global__ void DumpKernel(LruCacheContext<Key, Elem> cache_ctx, size_t start_key_index,
size_t end_key_index, uint32_t* n_dumped, Key* keys, Elem* values) {
ThreadContext thread_ctx{};
__shared__ Key warp_keys[kNumWarpPerBlock][kWarpSize];
__shared__ uint8_t warp_ages[kNumWarpPerBlock][kWarpSize];
for (uint32_t warp_start_key_index = start_key_index + thread_ctx.global_warp_id * kWarpSize;
warp_start_key_index < end_key_index;
warp_start_key_index += thread_ctx.num_warps * kWarpSize) {
Key lane_key = 0;
uint8_t lane_age = 0;
if (warp_start_key_index + thread_ctx.lane_id < end_key_index) {
lane_key = cache_ctx.keys[warp_start_key_index + thread_ctx.lane_id];
lane_age = cache_ctx.ages[warp_start_key_index + thread_ctx.lane_id];
}
__syncthreads();
warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
if (key_count == 0) { continue; }
uint32_t offset = 0;
if (thread_ctx.lane_id == 0) { offset = atomicAdd(n_dumped, key_count); }
offset = __shfl(offset, 0);
__syncthreads();
for (uint32_t i = 0; i < kWarpSize; ++i) {
const Key key = warp_keys[thread_ctx.warp_id_in_block][i];
const Key age = warp_ages[thread_ctx.warp_id_in_block][i];
if (age == 0) { continue; }
if (thread_ctx.lane_id == 0) { keys[offset] = key; }
__syncthreads();
for (uint32_t j = thread_ctx.lane_id; j < cache_ctx.line_size; j += kWarpSize) {
values[offset * cache_ctx.line_size + j] =
cache_ctx.lines[(warp_start_key_index + i) * cache_ctx.line_size + j];
}
__syncthreads();
offset += 1;
}
}
}
template<typename Key, typename Elem>
class LruCache : public Cache {
public:
OF_DISALLOW_COPY_AND_MOVE(LruCache);
explicit LruCache(const CacheOptions& options)
: device_index_{},
max_query_length_(0),
query_indices_buffer_(nullptr),
query_keys_buffer_(nullptr),
value_type_(options.value_type) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
InitLruCacheContext(options, &ctx_);
}
~LruCache() override {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(query_indices_buffer_));
OF_CUDA_CHECK(hipFree(query_keys_buffer_));
}
DestroyLruCacheContext(&ctx_);
}
uint32_t KeySize() const override { return sizeof(Key); }
uint32_t ValueSize() const override { return sizeof(Elem) * ctx_.line_size; }
DataType ValueType() const override { return value_type_; }
uint64_t Capacity() const override { return ctx_.n_set * kWarpSize; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length < max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(query_indices_buffer_));
OF_CUDA_CHECK(hipFree(query_keys_buffer_));
}
OF_CUDA_CHECK(hipMalloc(&query_indices_buffer_, query_length * sizeof(uint32_t)));
OF_CUDA_CHECK(hipMalloc(&query_keys_buffer_, query_length * sizeof(Key)));
max_query_length_ = query_length;
}
CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kLRU; }
void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(GetKernel<Key, Elem, true>, GetLaunchConfig(n_keys), ctx_, n_keys,
static_cast<const Key*>(keys), nullptr, n_missing,
static_cast<Key*>(missing_keys), missing_indices);
}
void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(GetKernel<Key, Elem, false>, GetLaunchConfig(n_keys), ctx_, n_keys,
static_cast<const Key*>(keys), static_cast<Elem*>(values), n_missing,
static_cast<Key*>(missing_keys), missing_indices);
}
void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(PutWithoutEvictingKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
n_keys, static_cast<const Key*>(keys),
static_cast<const Elem*>(values), n_evicted, query_keys_buffer_,
query_indices_buffer_);
cuda_stream->LaunchKernel(EvictKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
query_keys_buffer_, query_indices_buffer_,
static_cast<const Elem*>(values), n_evicted,
static_cast<Key*>(evicted_keys), static_cast<Elem*>(evicted_values));
}
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, void* keys, void* values) override {
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
const uint64_t max_dump_keys = end_key_index - start_key_index;
cuda_stream->LaunchKernel(
DumpKernel<Key, Elem>,
ep::CudaLaunchConfig((max_dump_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock, kBlockSize,
0),
ctx_, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
static_cast<Elem*>(values));
}
void Clear() override { ClearLruCacheContext<Key, Elem>(&ctx_); }
private:
int device_index_;
uint32_t max_query_length_;
LruCacheContext<Key, Elem> ctx_;
uint32_t* query_indices_buffer_;
Key* query_keys_buffer_;
DataType value_type_;
};
template<typename Key>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
if (options.value_size % sizeof(ulonglong2) == 0) {
return std::unique_ptr<Cache>(new LruCache<Key, ulonglong2>(options));
} else if (options.value_size % sizeof(uint64_t) == 0) {
return std::unique_ptr<Cache>(new LruCache<Key, uint64_t>(options));
} else if (options.value_size % sizeof(uint32_t) == 0) {
return std::unique_ptr<Cache>(new LruCache<Key, uint32_t>(options));
} else if (options.value_size % sizeof(uint16_t) == 0) {
return std::unique_ptr<Cache>(new LruCache<Key, uint16_t>(options));
} else {
return std::unique_ptr<Cache>(new LruCache<Key, uint8_t>(options));
}
}
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
if (options.key_size == sizeof(uint32_t)) {
return DispatchValueType<uint32_t>(options);
} else if (options.key_size == sizeof(uint64_t)) {
return DispatchValueType<uint64_t>(options);
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace
std::unique_ptr<Cache> NewLruCache(const CacheOptions& options) { return DispatchKeyType(options); }
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/mock_key_value_store.h"
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace embedding {
namespace {
template<typename Key>
class IteratorImpl : public KVIterator {
public:
OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
IteratorImpl(HashMap<Key, std::string>* store, uint32_t key_size, uint32_t value_size,
uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
uint32_t* host_num_buffer)
: store_(store),
pos_(store->begin()),
key_size_(key_size),
value_size_(value_size),
max_query_length_(max_query_length),
host_keys_buffer_(host_keys_buffer),
host_values_buffer_(host_values_buffer),
host_num_buffer_(host_num_buffer) {}
~IteratorImpl() override = default;
void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
void* values) override {
CHECK_LE(n_request, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_JUST(cuda_stream->Sync());
*host_num_buffer_ = 0;
while (*host_num_buffer_ < n_request && pos_ != store_->end()) {
reinterpret_cast<Key*>(host_keys_buffer_)[*host_num_buffer_] = pos_->first;
std::memcpy(reinterpret_cast<char*>(host_values_buffer_) + *host_num_buffer_ * value_size_,
pos_->second.data(), value_size_);
}
OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
const uint32_t num_keys = *host_num_buffer_;
if (num_keys != 0) {
OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
}
}
void Reset() override { pos_ = store_->begin(); }
private:
HashMap<Key, std::string>* store_;
typename HashMap<Key, std::string>::iterator pos_;
uint32_t key_size_;
uint32_t value_size_;
uint32_t max_query_length_;
void* host_keys_buffer_;
void* host_values_buffer_;
uint32_t* host_num_buffer_;
};
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
public:
OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
explicit KeyValueStoreImpl(const MockKeyValueStoreOptions& options)
: device_index_(-1), max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
key_size_ = options.key_size;
value_size_ = options.value_size;
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_query_values_),
value_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),
sizeof(uint32_t)));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * max_query_length_));
}
~KeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(hipHostFree(host_n_missing_));
}
uint32_t KeySize() const override { return key_size_; }
uint32_t ValueSize() const override { return value_size_; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * query_length));
max_query_length_ = query_length;
}
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
bool SnapshotExists(const std::string& name) override;
void LoadSnapshot(const std::string& name) override;
void LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) override;
void SaveSnapshot(const std::string& name) override;
private:
int device_index_;
uint32_t max_query_length_;
uint32_t key_size_;
uint32_t value_size_;
Key* host_query_keys_{};
uint8_t* host_query_values_{};
uint32_t* host_n_missing_{};
uint32_t* host_missing_indices_{};
HashMap<Key, std::string> store_;
HashMap<std::string, HashMap<Key, std::string>> snapshots_;
std::mutex mutex_;
};
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint32_t* n_missing, uint32_t* missing_indices) {
std::lock_guard<std::mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) {
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
*host_n_missing_ = 0;
for (uint32_t i = 0; i < num_keys; ++i) {
auto it = store_.find(host_query_keys_[i]);
if (it != store_.end()) {
std::memcpy(host_query_values_ + i * value_size_, it->second.data(), value_size_);
} else {
host_missing_indices_[*host_n_missing_] = i;
*host_n_missing_ += 1;
}
}
OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
(*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
}
template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
const void* values) {
std::lock_guard<std::mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) { return; }
OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
for (uint32_t i = 0; i < num_keys; ++i) {
store_[host_query_keys_[i]] = std::string(
reinterpret_cast<const char*>(host_query_values_) + i * value_size_, value_size_);
}
}
template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
return snapshots_.find(name) != snapshots_.end();
}
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
LoadSnapshot(name, nullptr);
}
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) {
CudaCurrentDeviceGuard guard(device_index_);
store_ = snapshots_[name];
if (Hook) {
IteratorImpl<Key> iterator(&store_, KeySize(), ValueSize(), max_query_length_, host_query_keys_,
host_query_values_, host_n_missing_);
Hook(&iterator);
}
}
template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
snapshots_[name] = store_;
}
} // namespace
std::unique_ptr<KeyValueStore> NewMockKeyValueStore(const MockKeyValueStoreOptions& options) {
if (options.key_size == sizeof(uint64_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
} else if (options.key_size == sizeof(uint32_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace embedding
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/mock_key_value_store.h"
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace embedding {
namespace {
template<typename Key>
class IteratorImpl : public KVIterator {
public:
OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
IteratorImpl(HashMap<Key, std::string>* store, uint32_t key_size, uint32_t value_size,
uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
uint32_t* host_num_buffer)
: store_(store),
pos_(store->begin()),
key_size_(key_size),
value_size_(value_size),
max_query_length_(max_query_length),
host_keys_buffer_(host_keys_buffer),
host_values_buffer_(host_values_buffer),
host_num_buffer_(host_num_buffer) {}
~IteratorImpl() override = default;
void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
void* values) override {
CHECK_LE(n_request, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_JUST(cuda_stream->Sync());
*host_num_buffer_ = 0;
while (*host_num_buffer_ < n_request && pos_ != store_->end()) {
reinterpret_cast<Key*>(host_keys_buffer_)[*host_num_buffer_] = pos_->first;
std::memcpy(reinterpret_cast<char*>(host_values_buffer_) + *host_num_buffer_ * value_size_,
pos_->second.data(), value_size_);
}
OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
const uint32_t num_keys = *host_num_buffer_;
if (num_keys != 0) {
OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
}
}
void Reset() override { pos_ = store_->begin(); }
private:
HashMap<Key, std::string>* store_;
typename HashMap<Key, std::string>::iterator pos_;
uint32_t key_size_;
uint32_t value_size_;
uint32_t max_query_length_;
void* host_keys_buffer_;
void* host_values_buffer_;
uint32_t* host_num_buffer_;
};
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
public:
OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
explicit KeyValueStoreImpl(const MockKeyValueStoreOptions& options)
: device_index_(-1), max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
key_size_ = options.key_size;
value_size_ = options.value_size;
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_query_values_),
value_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),
sizeof(uint32_t)));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * max_query_length_));
}
~KeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(hipHostFree(host_n_missing_));
}
uint32_t KeySize() const override { return key_size_; }
uint32_t ValueSize() const override { return value_size_; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * query_length));
max_query_length_ = query_length;
}
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
bool SnapshotExists(const std::string& name) override;
void LoadSnapshot(const std::string& name) override;
void LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) override;
void SaveSnapshot(const std::string& name) override;
private:
int device_index_;
uint32_t max_query_length_;
uint32_t key_size_;
uint32_t value_size_;
Key* host_query_keys_{};
uint8_t* host_query_values_{};
uint32_t* host_n_missing_{};
uint32_t* host_missing_indices_{};
HashMap<Key, std::string> store_;
HashMap<std::string, HashMap<Key, std::string>> snapshots_;
std::mutex mutex_;
};
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint32_t* n_missing, uint32_t* missing_indices) {
std::lock_guard<std::mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) {
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
*host_n_missing_ = 0;
for (uint32_t i = 0; i < num_keys; ++i) {
auto it = store_.find(host_query_keys_[i]);
if (it != store_.end()) {
std::memcpy(host_query_values_ + i * value_size_, it->second.data(), value_size_);
} else {
host_missing_indices_[*host_n_missing_] = i;
*host_n_missing_ += 1;
}
}
OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
(*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
}
template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
const void* values) {
std::lock_guard<std::mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) { return; }
OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
for (uint32_t i = 0; i < num_keys; ++i) {
store_[host_query_keys_[i]] = std::string(
reinterpret_cast<const char*>(host_query_values_) + i * value_size_, value_size_);
}
}
template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
return snapshots_.find(name) != snapshots_.end();
}
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
LoadSnapshot(name, nullptr);
}
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) {
CudaCurrentDeviceGuard guard(device_index_);
store_ = snapshots_[name];
if (Hook) {
IteratorImpl<Key> iterator(&store_, KeySize(), ValueSize(), max_query_length_, host_query_keys_,
host_query_values_, host_n_missing_);
Hook(&iterator);
}
}
template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
snapshots_[name] = store_;
}
} // namespace
std::unique_ptr<KeyValueStore> NewMockKeyValueStore(const MockKeyValueStoreOptions& options) {
if (options.key_size == sizeof(uint64_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
} else if (options.key_size == sizeof(uint32_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/persistent_table_key_value_store.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/persistent_table.h"
#include <robin_hood.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <dirent.h>
namespace oneflow {
namespace embedding {
namespace {
class IteratorImpl : public KVIterator {
public:
OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
IteratorImpl(PersistentTable::Iterator* base_iter, uint32_t key_size, uint32_t value_size,
uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
uint32_t* host_num_buffer)
: base_iter_(base_iter),
key_size_(key_size),
value_size_(value_size),
max_query_length_(max_query_length),
host_keys_buffer_(host_keys_buffer),
host_values_buffer_(host_values_buffer),
host_num_buffer_(host_num_buffer) {}
~IteratorImpl() override = default;
void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
void* values) override {
CHECK_LE(n_request, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_JUST(cuda_stream->Sync());
base_iter_->Next(n_request, host_num_buffer_, host_keys_buffer_, host_values_buffer_);
OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
const uint32_t num_keys = *host_num_buffer_;
if (num_keys != 0) {
OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
}
}
void Reset() override { base_iter_->Reset(); }
private:
PersistentTable::Iterator* base_iter_;
uint32_t key_size_;
uint32_t value_size_;
uint32_t max_query_length_;
void* host_keys_buffer_;
void* host_values_buffer_;
uint32_t* host_num_buffer_;
};
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
public:
OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
explicit KeyValueStoreImpl(const PersistentTableKeyValueStoreOptions& options)
: device_index_(-1), max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
key_size_ = options.table_options.key_size;
value_size_ = options.table_options.value_size;
table_ = NewPersistentTable(options.table_options);
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_query_values_),
value_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),
sizeof(uint32_t)));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * max_query_length_));
}
~KeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(hipHostFree(host_n_missing_));
}
uint32_t KeySize() const override { return key_size_; }
uint32_t ValueSize() const override { return value_size_; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * query_length));
max_query_length_ = query_length;
}
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
bool SnapshotExists(const std::string& name) override;
void LoadSnapshot(const std::string& name) override;
void LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) override;
void SaveSnapshot(const std::string& name) override;
private:
int device_index_;
uint32_t max_query_length_;
uint32_t key_size_;
uint32_t value_size_;
Key* host_query_keys_{};
uint8_t* host_query_values_{};
uint32_t* host_n_missing_{};
uint32_t* host_missing_indices_{};
std::mutex mutex_;
std::unique_ptr<PersistentTable> table_;
};
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint32_t* n_missing, uint32_t* missing_indices) {
std::lock_guard<std::mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) {
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
table_->Get(num_keys, host_query_keys_, host_query_values_, host_n_missing_,
host_missing_indices_);
OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
(*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
}
template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
const void* values) {
std::lock_guard<std::mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) { return; }
OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
table_->Put(num_keys, host_query_keys_, host_query_values_);
}
template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
return table_->SnapshotExists(name);
}
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
LoadSnapshot(name, nullptr);
}
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) {
CudaCurrentDeviceGuard guard(device_index_);
if (Hook) {
table_->LoadSnapshot(name, [&](PersistentTable::Iterator* chunk_iterator) {
IteratorImpl iterator(chunk_iterator, KeySize(), ValueSize(), max_query_length_,
host_query_keys_, host_query_values_, host_n_missing_);
Hook(&iterator);
});
} else {
table_->LoadSnapshot(name);
}
}
template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
table_->SaveSnapshot(name);
}
} // namespace
std::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(
const PersistentTableKeyValueStoreOptions& options) {
if (options.table_options.key_size == sizeof(uint64_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
} else if (options.table_options.key_size == sizeof(uint32_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace embedding
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/persistent_table_key_value_store.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/persistent_table.h"
#include <robin_hood.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <dirent.h>
namespace oneflow {
namespace embedding {
namespace {
class IteratorImpl : public KVIterator {
public:
OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
IteratorImpl(PersistentTable::Iterator* base_iter, uint32_t key_size, uint32_t value_size,
uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
uint32_t* host_num_buffer)
: base_iter_(base_iter),
key_size_(key_size),
value_size_(value_size),
max_query_length_(max_query_length),
host_keys_buffer_(host_keys_buffer),
host_values_buffer_(host_values_buffer),
host_num_buffer_(host_num_buffer) {}
~IteratorImpl() override = default;
void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
void* values) override {
CHECK_LE(n_request, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_JUST(cuda_stream->Sync());
base_iter_->Next(n_request, host_num_buffer_, host_keys_buffer_, host_values_buffer_);
OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
const uint32_t num_keys = *host_num_buffer_;
if (num_keys != 0) {
OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
}
}
void Reset() override { base_iter_->Reset(); }
private:
PersistentTable::Iterator* base_iter_;
uint32_t key_size_;
uint32_t value_size_;
uint32_t max_query_length_;
void* host_keys_buffer_;
void* host_values_buffer_;
uint32_t* host_num_buffer_;
};
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
public:
OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
explicit KeyValueStoreImpl(const PersistentTableKeyValueStoreOptions& options)
: device_index_(-1), max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
key_size_ = options.table_options.key_size;
value_size_ = options.table_options.value_size;
table_ = NewPersistentTable(options.table_options);
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_query_values_),
value_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),
sizeof(uint32_t)));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * max_query_length_));
}
~KeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(hipHostFree(host_n_missing_));
}
uint32_t KeySize() const override { return key_size_; }
uint32_t ValueSize() const override { return value_size_; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * query_length));
max_query_length_ = query_length;
}
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
bool SnapshotExists(const std::string& name) override;
void LoadSnapshot(const std::string& name) override;
void LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) override;
void SaveSnapshot(const std::string& name) override;
private:
int device_index_;
uint32_t max_query_length_;
uint32_t key_size_;
uint32_t value_size_;
Key* host_query_keys_{};
uint8_t* host_query_values_{};
uint32_t* host_n_missing_{};
uint32_t* host_missing_indices_{};
std::mutex mutex_;
std::unique_ptr<PersistentTable> table_;
};
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint32_t* n_missing, uint32_t* missing_indices) {
std::lock_guard<std::mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) {
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
table_->Get(num_keys, host_query_keys_, host_query_values_, host_n_missing_,
host_missing_indices_);
OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_,
hipMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
(*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
}
template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
const void* values) {
std::lock_guard<std::mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) { return; }
OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
table_->Put(num_keys, host_query_keys_, host_query_values_);
}
template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
return table_->SnapshotExists(name);
}
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
LoadSnapshot(name, nullptr);
}
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) {
CudaCurrentDeviceGuard guard(device_index_);
if (Hook) {
table_->LoadSnapshot(name, [&](PersistentTable::Iterator* chunk_iterator) {
IteratorImpl iterator(chunk_iterator, KeySize(), ValueSize(), max_query_length_,
host_query_keys_, host_query_values_, host_n_missing_);
Hook(&iterator);
});
} else {
table_->LoadSnapshot(name);
}
}
template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
table_->SaveSnapshot(name);
}
} // namespace
std::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(
const PersistentTableKeyValueStoreOptions& options) {
if (options.table_options.key_size == sizeof(uint64_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
} else if (options.table_options.key_size == sizeof(uint32_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_device.h"
#include "oneflow/core/ep/rocm/cuda_event.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif
namespace oneflow {
namespace ep {
namespace {
constexpr size_t kDefaultConstBufElementCount = 1024 * 1024;
template<typename T>
void CreateConstBuffer(void** buf, T value, size_t n) {
OF_CUDA_CHECK(hipMalloc(buf, n * sizeof(T)));
std::vector<T> host(n, value);
OF_CUDA_CHECK(hipMemcpy(*buf, host.data(), n * sizeof(T), hipMemcpyDefault));
}
} // namespace
CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager)
: device_index_(device_index),
event_flags_{},
properties_{},
device_manager_(device_manager),
const_buf_elem_cnt_(0),
const_zeros_buffer_(nullptr),
const_ones_buffer_fp32_(nullptr),
const_ones_buffer_fp16_(nullptr),
const_ones_buffer_bf16_(nullptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipGetDeviceProperties(&properties_, device_index_));
event_flags_ = hipEventDisableTiming;
if (ParseBooleanFromEnv("ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC", false)) {
event_flags_ |= hipEventBlockingSync;
}
const_buf_elem_cnt_ = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT",
kDefaultConstBufElementCount);
if (const_buf_elem_cnt_ > 0) {
CreateConstBuffer<float>(&const_zeros_buffer_, static_cast<float>(0), const_buf_elem_cnt_);
CreateConstBuffer<float>(&const_ones_buffer_fp32_, static_cast<float>(1.0),
const_buf_elem_cnt_);
CreateConstBuffer<half>(&const_ones_buffer_fp16_, static_cast<half>(1.0), const_buf_elem_cnt_);
// #if CUDA_VERSION >= 11000
// CreateConstBuffer<nv_bfloat16>(&const_ones_buffer_bf16_, static_cast<nv_bfloat16>(1.0),
// const_buf_elem_cnt_);
// #endif
}
}
CudaDevice::~CudaDevice() {
CudaCurrentDeviceGuard guard(device_index_);
for (auto* event : events_) { delete event; }
OF_CUDA_CHECK(hipFree(const_zeros_buffer_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_fp32_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_fp16_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_bf16_));
}
void CudaDevice::SetAsActiveDevice() { OF_CUDA_CHECK(hipSetDevice(device_index_)); }
Stream* CudaDevice::CreateStream() {
CudaCurrentDeviceGuard guard(device_index_);
return new CudaStream(this);
}
void CudaDevice::DestroyStream(Stream* stream) {
CudaCurrentDeviceGuard guard(device_index_);
delete stream;
}
void CudaDevice::CreateEvents(Event** events, size_t count) {
size_t copied = 0;
{
std::lock_guard<std::mutex> lock(events_mutex_);
copied = std::min(count, events_.size());
size_t offset = events_.size() - copied;
std::copy(events_.begin() + offset, events_.end(), events);
events_.resize(offset);
}
if (copied != count) {
CudaCurrentDeviceGuard guard(device_index_);
for (size_t i = copied; i < count; ++i) { events[i] = new CudaEvent(event_flags_); }
}
}
void CudaDevice::DestroyEvents(Event** events, size_t count) {
std::lock_guard<std::mutex> lock(events_mutex_);
events_.insert(events_.end(), events, events + count);
}
Maybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) {
CudaCurrentDeviceGuard guard(device_index_);
CHECK(!options.HasPinnedDevice());
hipError_t err = hipMalloc(ptr, size);
if (err != hipSuccess) {
return Error::RuntimeError() << hipGetErrorString(err);
} else {
return Maybe<void>::Ok();
}
}
void CudaDevice::Free(const AllocationOptions& attr, void* ptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipFree(ptr));
}
Maybe<void> CudaDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) {
CudaCurrentDeviceGuard guard(device_index_);
hipError_t err = NumaAwareCudaMallocHost(device_index_, ptr, size);
if (err != hipSuccess) {
return Error::RuntimeError() << hipGetErrorString(err);
} else {
return Maybe<void>::Ok();
}
}
void CudaDevice::FreePinned(const AllocationOptions& options, void* ptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipHostFree(ptr));
}
const hipDeviceProp_t& CudaDevice::properties() const { return properties_; }
const void* CudaDevice::GetConstZeros(DataType data_type, size_t n) const {
if (GetSizeOfDataType(data_type) * n
<= GetSizeOfDataType(DataType::kFloat) * const_buf_elem_cnt_) {
return const_zeros_buffer_;
} else {
return nullptr;
}
}
const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const {
if (n <= const_buf_elem_cnt_) {
if (data_type == DataType::kFloat) {
return const_ones_buffer_fp32_;
} else if (data_type == DataType::kFloat16) {
return const_ones_buffer_fp16_;
} else if (data_type == DataType::kBFloat16) {
return const_ones_buffer_bf16_;
} else {
return nullptr;
}
} else {
return nullptr;
}
}
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_device.h"
#include "oneflow/core/ep/rocm/cuda_event.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif
namespace oneflow {
namespace ep {
namespace {
constexpr size_t kDefaultConstBufElementCount = 1024 * 1024;
template<typename T>
void CreateConstBuffer(void** buf, T value, size_t n) {
OF_CUDA_CHECK(hipMalloc(buf, n * sizeof(T)));
std::vector<T> host(n, value);
OF_CUDA_CHECK(hipMemcpy(*buf, host.data(), n * sizeof(T), hipMemcpyDefault));
}
} // namespace
CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager)
: device_index_(device_index),
event_flags_{},
properties_{},
device_manager_(device_manager),
const_buf_elem_cnt_(0),
const_zeros_buffer_(nullptr),
const_ones_buffer_fp32_(nullptr),
const_ones_buffer_fp16_(nullptr),
const_ones_buffer_bf16_(nullptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipGetDeviceProperties(&properties_, device_index_));
event_flags_ = hipEventDisableTiming;
if (ParseBooleanFromEnv("ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC", false)) {
event_flags_ |= hipEventBlockingSync;
}
const_buf_elem_cnt_ = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT",
kDefaultConstBufElementCount);
if (const_buf_elem_cnt_ > 0) {
CreateConstBuffer<float>(&const_zeros_buffer_, static_cast<float>(0), const_buf_elem_cnt_);
CreateConstBuffer<float>(&const_ones_buffer_fp32_, static_cast<float>(1.0),
const_buf_elem_cnt_);
CreateConstBuffer<half>(&const_ones_buffer_fp16_, static_cast<half>(1.0), const_buf_elem_cnt_);
// #if CUDA_VERSION >= 11000
// CreateConstBuffer<nv_bfloat16>(&const_ones_buffer_bf16_, static_cast<nv_bfloat16>(1.0),
// const_buf_elem_cnt_);
// #endif
}
}
CudaDevice::~CudaDevice() {
CudaCurrentDeviceGuard guard(device_index_);
for (auto* event : events_) { delete event; }
OF_CUDA_CHECK(hipFree(const_zeros_buffer_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_fp32_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_fp16_));
OF_CUDA_CHECK(hipFree(const_ones_buffer_bf16_));
}
void CudaDevice::SetAsActiveDevice() { OF_CUDA_CHECK(hipSetDevice(device_index_)); }
Stream* CudaDevice::CreateStream() {
CudaCurrentDeviceGuard guard(device_index_);
return new CudaStream(this);
}
void CudaDevice::DestroyStream(Stream* stream) {
CudaCurrentDeviceGuard guard(device_index_);
delete stream;
}
void CudaDevice::CreateEvents(Event** events, size_t count) {
size_t copied = 0;
{
std::lock_guard<std::mutex> lock(events_mutex_);
copied = std::min(count, events_.size());
size_t offset = events_.size() - copied;
std::copy(events_.begin() + offset, events_.end(), events);
events_.resize(offset);
}
if (copied != count) {
CudaCurrentDeviceGuard guard(device_index_);
for (size_t i = copied; i < count; ++i) { events[i] = new CudaEvent(event_flags_); }
}
}
void CudaDevice::DestroyEvents(Event** events, size_t count) {
std::lock_guard<std::mutex> lock(events_mutex_);
events_.insert(events_.end(), events, events + count);
}
Maybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) {
CudaCurrentDeviceGuard guard(device_index_);
CHECK(!options.HasPinnedDevice());
hipError_t err = hipMalloc(ptr, size);
if (err != hipSuccess) {
return Error::RuntimeError() << hipGetErrorString(err);
} else {
return Maybe<void>::Ok();
}
}
void CudaDevice::Free(const AllocationOptions& attr, void* ptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipFree(ptr));
}
Maybe<void> CudaDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) {
CudaCurrentDeviceGuard guard(device_index_);
hipError_t err = NumaAwareCudaMallocHost(device_index_, ptr, size);
if (err != hipSuccess) {
return Error::RuntimeError() << hipGetErrorString(err);
} else {
return Maybe<void>::Ok();
}
}
void CudaDevice::FreePinned(const AllocationOptions& options, void* ptr) {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipHostFree(ptr));
}
const hipDeviceProp_t& CudaDevice::properties() const { return properties_; }
const void* CudaDevice::GetConstZeros(DataType data_type, size_t n) const {
if (GetSizeOfDataType(data_type) * n
<= GetSizeOfDataType(DataType::kFloat) * const_buf_elem_cnt_) {
return const_zeros_buffer_;
} else {
return nullptr;
}
}
const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const {
if (n <= const_buf_elem_cnt_) {
if (data_type == DataType::kFloat) {
return const_ones_buffer_fp32_;
} else if (data_type == DataType::kFloat16) {
return const_ones_buffer_fp16_;
} else if (data_type == DataType::kBFloat16) {
return const_ones_buffer_bf16_;
} else {
return nullptr;
}
} else {
return nullptr;
}
}
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
#include "oneflow/core/ep/include/device.h"
#include "oneflow/core/common/data_type.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
class CudaDevice : public Device {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaDevice);
explicit CudaDevice(int device_index, DeviceManager* device_manager);
~CudaDevice() override;
void SetAsActiveDevice() override;
DeviceType device_type() const override { return DeviceType::kCUDA; }
size_t device_index() const override { return device_index_; }
DeviceManager* device_manager() const override { return device_manager_; }
Stream* CreateStream() override;
void DestroyStream(Stream* stream) override;
void CreateEvents(Event** events, size_t count) override;
void DestroyEvents(Event** events, size_t count) override;
Maybe<void> Alloc(const AllocationOptions& options, void** ptr, size_t size) override;
void Free(const AllocationOptions& options, void* ptr) override;
Maybe<void> AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override;
void FreePinned(const AllocationOptions& options, void* ptr) override;
const hipDeviceProp_t& properties() const;
const void* GetConstZeros(DataType data_type, size_t n) const;
const void* GetConstOnes(DataType data_type, size_t n) const;
private:
int device_index_;
std::mutex events_mutex_;
std::vector<Event*> events_;
unsigned int event_flags_;
hipDeviceProp_t properties_;
DeviceManager* device_manager_;
int64_t const_buf_elem_cnt_;
void* const_zeros_buffer_;
void* const_ones_buffer_fp32_;
void* const_ones_buffer_fp16_;
void* const_ones_buffer_bf16_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
#include "oneflow/core/ep/include/device.h"
#include "oneflow/core/common/data_type.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
class CudaDevice : public Device {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaDevice);
explicit CudaDevice(int device_index, DeviceManager* device_manager);
~CudaDevice() override;
void SetAsActiveDevice() override;
DeviceType device_type() const override { return DeviceType::kCUDA; }
size_t device_index() const override { return device_index_; }
DeviceManager* device_manager() const override { return device_manager_; }
Stream* CreateStream() override;
void DestroyStream(Stream* stream) override;
void CreateEvents(Event** events, size_t count) override;
void DestroyEvents(Event** events, size_t count) override;
Maybe<void> Alloc(const AllocationOptions& options, void** ptr, size_t size) override;
void Free(const AllocationOptions& options, void* ptr) override;
Maybe<void> AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override;
void FreePinned(const AllocationOptions& options, void* ptr) override;
const hipDeviceProp_t& properties() const;
const void* GetConstZeros(DataType data_type, size_t n) const;
const void* GetConstOnes(DataType data_type, size_t n) const;
private:
int device_index_;
std::mutex events_mutex_;
std::vector<Event*> events_;
unsigned int event_flags_;
hipDeviceProp_t properties_;
DeviceManager* device_manager_;
int64_t const_buf_elem_cnt_;
void* const_zeros_buffer_;
void* const_ones_buffer_fp32_;
void* const_ones_buffer_fp16_;
void* const_ones_buffer_bf16_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_device_manager.h"
#include "oneflow/core/device/cuda_util.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
CudaDeviceManager::CudaDeviceManager(DeviceManagerRegistry* registry) : registry_(registry) {}
CudaDeviceManager::~CudaDeviceManager() = default;
DeviceManagerRegistry* CudaDeviceManager::registry() const { return registry_; }
std::shared_ptr<Device> CudaDeviceManager::GetDevice(size_t device_index) {
std::lock_guard<std::mutex> lock(devices_mutex_);
if (device_index < devices_.size() && devices_.at(device_index)) {
return devices_.at(device_index);
}
auto device = std::make_shared<CudaDevice>(device_index, this);
if (device_index >= devices_.size()) { devices_.resize(device_index + 1); }
devices_.at(device_index) = device;
return device;
}
size_t CudaDeviceManager::GetDeviceCount(size_t primary_device_index) {
CudaCurrentDeviceGuard guard(primary_device_index);
return this->GetDeviceCount();
}
size_t CudaDeviceManager::GetDeviceCount() {
int count = 0;
hipError_t err = hipGetDeviceCount(&count);
if (err == hipErrorNoDevice || err == hipErrorInsufficientDriver) { return 0; }
OF_CUDA_CHECK(err);
return count;
}
size_t CudaDeviceManager::GetActiveDeviceIndex() {
int device = 0;
OF_CUDA_CHECK(hipGetDevice(&device));
return static_cast<size_t>(device);
}
void CudaDeviceManager::SetActiveDeviceByIndex(size_t device_index) {
OF_CUDA_CHECK(hipSetDevice(static_cast<int>(device_index)));
}
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_device_manager.h"
#include "oneflow/core/device/cuda_util.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
CudaDeviceManager::CudaDeviceManager(DeviceManagerRegistry* registry) : registry_(registry) {}
CudaDeviceManager::~CudaDeviceManager() = default;
DeviceManagerRegistry* CudaDeviceManager::registry() const { return registry_; }
std::shared_ptr<Device> CudaDeviceManager::GetDevice(size_t device_index) {
std::lock_guard<std::mutex> lock(devices_mutex_);
if (device_index < devices_.size() && devices_.at(device_index)) {
return devices_.at(device_index);
}
auto device = std::make_shared<CudaDevice>(device_index, this);
if (device_index >= devices_.size()) { devices_.resize(device_index + 1); }
devices_.at(device_index) = device;
return device;
}
size_t CudaDeviceManager::GetDeviceCount(size_t primary_device_index) {
CudaCurrentDeviceGuard guard(primary_device_index);
return this->GetDeviceCount();
}
size_t CudaDeviceManager::GetDeviceCount() {
int count = 0;
hipError_t err = hipGetDeviceCount(&count);
if (err == hipErrorNoDevice || err == hipErrorInsufficientDriver) { return 0; }
OF_CUDA_CHECK(err);
return count;
}
size_t CudaDeviceManager::GetActiveDeviceIndex() {
int device = 0;
OF_CUDA_CHECK(hipGetDevice(&device));
return static_cast<size_t>(device);
}
void CudaDeviceManager::SetActiveDeviceByIndex(size_t device_index) {
OF_CUDA_CHECK(hipSetDevice(static_cast<int>(device_index)));
}
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
#include "oneflow/core/ep/include/device_manager.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
class CudaDevice;
class CudaDeviceManager : public DeviceManager {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManager);
CudaDeviceManager(DeviceManagerRegistry* registry);
~CudaDeviceManager() override;
DeviceManagerRegistry* registry() const override;
std::shared_ptr<Device> GetDevice(size_t device_index) override;
size_t GetDeviceCount(size_t primary_device_index) override;
size_t GetDeviceCount() override;
size_t GetActiveDeviceIndex() override;
void SetActiveDeviceByIndex(size_t device_index) override;
private:
std::mutex devices_mutex_;
std::vector<std::shared_ptr<CudaDevice>> devices_;
DeviceManagerRegistry* registry_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
#include "oneflow/core/ep/include/device_manager.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
class CudaDevice;
class CudaDeviceManager : public DeviceManager {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManager);
CudaDeviceManager(DeviceManagerRegistry* registry);
~CudaDeviceManager() override;
DeviceManagerRegistry* registry() const override;
std::shared_ptr<Device> GetDevice(size_t device_index) override;
size_t GetDeviceCount(size_t primary_device_index) override;
size_t GetDeviceCount() override;
size_t GetActiveDeviceIndex() override;
void SetActiveDeviceByIndex(size_t device_index) override;
private:
std::mutex devices_mutex_;
std::vector<std::shared_ptr<CudaDevice>> devices_;
DeviceManagerRegistry* registry_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/device_manager_factory.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
#include "oneflow/core/ep/rocm/cuda_device_manager.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
#include <rccl.h>
namespace oneflow {
namespace ep {
namespace {
std::string GetCudaVersionString(int version) {
return std::to_string(version / 1000) + "." + std::to_string((version % 1000) / 10);
}
bool GetCudnnVersion(size_t* major, size_t* minor, size_t* patch) {
miopenStatus_t status = miopenGetVersion(major, minor, patch);
if (status == miopenStatusSuccess) {
return true;
} else {
LOG(ERROR) << "Failed to get cuDNN version: " << miopenGetErrorString(status);
return false;
}
}
bool GetCudnnVersionString(std::string* version) {
size_t version_major = 0;
size_t version_minor = 0;
size_t version_patch = 0;
if (!GetCudnnVersion(&version_major, &version_minor, &version_patch)) { return false; }
*version = std::to_string(version_major) + "." + std::to_string(version_minor) + "."
+ std::to_string(version_patch);
return true;
}
void CudaDumpVersionInfo() {
{
int cuda_runtime_version = 0;
hipError_t err = hipRuntimeGetVersion(&cuda_runtime_version);
if (err == hipSuccess) {
LOG(INFO) << "CUDA runtime version: " << GetCudaVersionString(cuda_runtime_version);
} else {
LOG(ERROR) << "Failed to get cuda runtime version: " << hipGetErrorString(err);
}
}
{
std::string cudnn_version_string;
if (GetCudnnVersionString(&cudnn_version_string)) {
LOG(INFO) << "cuDNN version: " << cudnn_version_string;
}
}
{
int nccl_version = 0;
ncclResult_t result = ncclGetVersion(&nccl_version);
if (result == ncclSuccess) {
int nccl_version_major =
(nccl_version >= 20900) ? (nccl_version / 10000) : (nccl_version / 1000);
int nccl_version_minor =
(nccl_version >= 20900) ? (nccl_version % 10000) / 100 : (nccl_version % 1000) / 100;
int nccl_version_patch = (nccl_version % 100);
LOG(INFO) << "NCCL version: " << nccl_version_major << "." << nccl_version_minor << "."
<< nccl_version_patch;
} else {
LOG(ERROR) << "Failed to get NCCL version: " << ncclGetErrorString(result);
}
}
}
class CudaDeviceManagerFactory : public DeviceManagerFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManagerFactory);
CudaDeviceManagerFactory() = default;
~CudaDeviceManagerFactory() override = default;
std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {
return std::make_unique<CudaDeviceManager>(registry);
}
DeviceType device_type() const override { return DeviceType::kCUDA; }
std::string device_type_name() const override { return "cuda"; }
void DumpVersionInfo() const override { CudaDumpVersionInfo(); }
};
COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(
std::make_unique<CudaDeviceManagerFactory>()))
} // namespace
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/device_manager_factory.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
#include "oneflow/core/ep/rocm/cuda_device_manager.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
#include <rccl.h>
namespace oneflow {
namespace ep {
namespace {
std::string GetCudaVersionString(int version) {
return std::to_string(version / 1000) + "." + std::to_string((version % 1000) / 10);
}
bool GetCudnnVersion(size_t* major, size_t* minor, size_t* patch) {
miopenStatus_t status = miopenGetVersion(major, minor, patch);
if (status == miopenStatusSuccess) {
return true;
} else {
LOG(ERROR) << "Failed to get cuDNN version: " << miopenGetErrorString(status);
return false;
}
}
bool GetCudnnVersionString(std::string* version) {
size_t version_major = 0;
size_t version_minor = 0;
size_t version_patch = 0;
if (!GetCudnnVersion(&version_major, &version_minor, &version_patch)) { return false; }
*version = std::to_string(version_major) + "." + std::to_string(version_minor) + "."
+ std::to_string(version_patch);
return true;
}
void CudaDumpVersionInfo() {
{
int cuda_runtime_version = 0;
hipError_t err = hipRuntimeGetVersion(&cuda_runtime_version);
if (err == hipSuccess) {
LOG(INFO) << "CUDA runtime version: " << GetCudaVersionString(cuda_runtime_version);
} else {
LOG(ERROR) << "Failed to get cuda runtime version: " << hipGetErrorString(err);
}
}
{
std::string cudnn_version_string;
if (GetCudnnVersionString(&cudnn_version_string)) {
LOG(INFO) << "cuDNN version: " << cudnn_version_string;
}
}
{
int nccl_version = 0;
ncclResult_t result = ncclGetVersion(&nccl_version);
if (result == ncclSuccess) {
int nccl_version_major =
(nccl_version >= 20900) ? (nccl_version / 10000) : (nccl_version / 1000);
int nccl_version_minor =
(nccl_version >= 20900) ? (nccl_version % 10000) / 100 : (nccl_version % 1000) / 100;
int nccl_version_patch = (nccl_version % 100);
LOG(INFO) << "NCCL version: " << nccl_version_major << "." << nccl_version_minor << "."
<< nccl_version_patch;
} else {
LOG(ERROR) << "Failed to get NCCL version: " << ncclGetErrorString(result);
}
}
}
class CudaDeviceManagerFactory : public DeviceManagerFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManagerFactory);
CudaDeviceManagerFactory() = default;
~CudaDeviceManagerFactory() override = default;
std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {
return std::make_unique<CudaDeviceManager>(registry);
}
DeviceType device_type() const override { return DeviceType::kCUDA; }
std::string device_type_name() const override { return "cuda"; }
void DumpVersionInfo() const override { CudaDumpVersionInfo(); }
};
COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(
std::make_unique<CudaDeviceManagerFactory>()))
} // namespace
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_event.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
CudaEvent::CudaEvent(unsigned int flags) : cuda_event_{} {
OF_CUDA_CHECK(hipEventCreateWithFlags(&cuda_event_, flags));
}
CudaEvent::~CudaEvent() { OF_CUDA_CHECK(hipEventDestroy(cuda_event_)); }
Maybe<bool> CudaEvent::QueryDone() {
hipError_t err = hipEventQuery(cuda_event_);
if (err == hipSuccess) {
return Maybe<bool>(true);
} else if (err == hipErrorNotReady) {
return Maybe<bool>(false);
} else {
return Error::RuntimeError() << hipGetErrorString(err);
}
}
Maybe<void> CudaEvent::Sync() {
hipError_t err = hipEventSynchronize(cuda_event_);
if (err == hipSuccess) {
return Maybe<void>::Ok();
} else {
return Error::RuntimeError() << hipGetErrorString(err);
}
}
hipEvent_t CudaEvent::cuda_event() { return cuda_event_; }
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_event.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
CudaEvent::CudaEvent(unsigned int flags) : cuda_event_{} {
OF_CUDA_CHECK(hipEventCreateWithFlags(&cuda_event_, flags));
}
CudaEvent::~CudaEvent() { OF_CUDA_CHECK(hipEventDestroy(cuda_event_)); }
Maybe<bool> CudaEvent::QueryDone() {
hipError_t err = hipEventQuery(cuda_event_);
if (err == hipSuccess) {
return Maybe<bool>(true);
} else if (err == hipErrorNotReady) {
return Maybe<bool>(false);
} else {
return Error::RuntimeError() << hipGetErrorString(err);
}
}
Maybe<void> CudaEvent::Sync() {
hipError_t err = hipEventSynchronize(cuda_event_);
if (err == hipSuccess) {
return Maybe<void>::Ok();
} else {
return Error::RuntimeError() << hipGetErrorString(err);
}
}
hipEvent_t CudaEvent::cuda_event() { return cuda_event_; }
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
#include "oneflow/core/ep/include/event.h"
#ifdef WITH_ROCM
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace ep {
class CudaEvent : public Event {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaEvent);
explicit CudaEvent(unsigned int flags);
~CudaEvent() override;
Maybe<bool> QueryDone() override;
Maybe<void> Sync() override;
hipEvent_t cuda_event();
private:
hipEvent_t cuda_event_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
#include "oneflow/core/ep/include/event.h"
#ifdef WITH_ROCM
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace ep {
class CudaEvent : public Event {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaEvent);
explicit CudaEvent(unsigned int flags);
~CudaEvent() override;
Maybe<bool> QueryDone() override;
Maybe<void> Sync() override;
hipEvent_t cuda_event();
private:
hipEvent_t cuda_event_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/hardware/node_device_descriptor_manager.h"
#include "oneflow/core/hardware/cuda_device_descriptor.h"
#include "oneflow/core/ep/rocm/cuda_event.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
namespace {
constexpr size_t kDefaultWorkspaceSize = 4 * 1024 * 1024; // 4M
void SetAffinityByDevice(int dev_id) {
auto node_device_desc_mgr = Singleton<hardware::NodeDeviceDescriptorManager>::Get();
if (node_device_desc_mgr == nullptr) { return; }
auto node_device_desc = node_device_desc_mgr->GetLocalNodeDeviceDescriptor();
auto cuda_device = std::dynamic_pointer_cast<const hardware::CudaDeviceDescriptor>(
node_device_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev_id));
if (!cuda_device) { return; }
node_device_desc->Topology()->SetCPUAffinityByPCIBusID(cuda_device->PCIBusID());
node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID());
}
} // namespace
#ifdef WITH_ROCM_GRAPHS
CudaGraphExecutable::CudaGraphExecutable() : graph_exec_(nullptr), dev_(-1) {}
CudaGraphExecutable::~CudaGraphExecutable() { Reset(); }
void CudaGraphExecutable::Update(hipGraph_t graph) {
int dev = -1;
OF_CUDA_CHECK(hipGetDevice(&dev));
if (dev != dev_) { Reset(); }
dev_ = dev;
if (graph_exec_ != nullptr) {
hipGraphExecUpdateResult update_result{};
hipGraphNode_t error_node = nullptr;
OF_CUDA_CHECK(hipGraphExecUpdate(graph_exec_, graph, &error_node, &update_result));
if (update_result == hipGraphExecUpdateSuccess) { return; }
}
Reset();
OF_CUDA_CHECK(hipGraphInstantiate(&graph_exec_, graph, NULL, NULL, 0));
}
void CudaGraphExecutable::Launch(hipStream_t stream) const {
OF_CUDA_CHECK(hipGraphLaunch(graph_exec_, stream));
}
bool CudaGraphExecutable::IsInstantiated() const { return graph_exec_ != nullptr; }
void CudaGraphExecutable::Reset() {
if (graph_exec_ != nullptr) {
CudaCurrentDeviceGuard guard(dev_);
OF_CUDA_CHECK(hipGraphExecDestroy(graph_exec_));
}
}
#endif // WITH_ROCM_GRAPHS
CudaStream::CudaStream(CudaDevice* device)
: device_index_(device->device_index()), device_(device) {
CudaCurrentDeviceGuard guard(device_index_);
// cuda_stream
OF_CUDA_CHECK(hipStreamCreate(&cuda_stream_));
// cublas_handle
OF_CUBLAS_CHECK(hipblasCreate(&cublas_handle_));
OF_CUBLAS_CHECK(hipblasSetStream(cublas_handle_, cuda_stream_));
workspace_size_ = kDefaultWorkspaceSize;
OF_CUDA_CHECK(hipMalloc(&workspace_, workspace_size_));
OF_CUDNN_CHECK(hipdnnCreate(&cudnn_handle_));
OF_CUDNN_CHECK(hipdnnSetStream(cudnn_handle_, cuda_stream_));
}
CudaStream::~CudaStream() {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipStreamSynchronize(cuda_stream_));
OF_CUDNN_CHECK(hipdnnDestroy(cudnn_handle_));
OF_CUBLAS_CHECK(hipblasDestroy(cublas_handle_));
OF_CUDA_CHECK(hipStreamDestroy(cuda_stream_));
OF_CUDA_CHECK(hipFree(workspace_));
}
Maybe<void> CudaStream::OnExecutionContextSetup() {
OF_CUDA_CHECK(hipSetDevice(device_index_));
SetAffinityByDevice(device_index_);
return Maybe<void>::Ok();
}
Maybe<void> CudaStream::OnExecutionContextTeardown() { return Maybe<void>::Ok(); }
DeviceType CudaStream::device_type() const { return DeviceType::kCUDA; }
CudaDevice* CudaStream::device() const { return device_; }
Maybe<void> CudaStream::Sync() {
hipError_t err = hipStreamSynchronize(cuda_stream_);
if (err == hipSuccess) {
return Maybe<void>::Ok();
} else {
return Error::RuntimeError() << hipGetErrorString(err) << " (" << err << ") ";
}
}
void CudaStream::RecordEvent(Event* event) {
auto* cuda_event = static_cast<CudaEvent*>(event); // NOLINT
OF_CUDA_CHECK(hipEventRecord(cuda_event->cuda_event(), cuda_stream_));
}
hipStream_t CudaStream::cuda_stream() const { return cuda_stream_; }
hipblasHandle_t CudaStream::cublas_handle() const { return cublas_handle_; }
void* CudaStream::cublas_workspace() const { return workspace_; }
size_t CudaStream::cublas_workspace_size() const { return workspace_size_; }
hipdnnHandle_t CudaStream::cudnn_handle() const { return cudnn_handle_; }
const hipDeviceProp_t& CudaStream::device_properties() const { return device_->properties(); }
int CudaStream::cuda_arch() const {
return device_->properties().major * 100 + device_->properties().minor * 10;
}
#ifdef WITH_ROCM_GRAPHS
void CudaStream::BeginGraphCapture() {
CHECK(!is_graph_capturing_);
is_graph_capturing_ = true;
OF_CUDA_CHECK(hipStreamBeginCapture(cuda_stream_, hipStreamCaptureModeThreadLocal));
}
void CudaStream::EndGraphCapture(CudaGraphExecutable* executable) {
hipGraph_t graph = nullptr;
OF_CUDA_CHECK(hipStreamEndCapture(cuda_stream_, &graph));
executable->Update(graph);
OF_CUDA_CHECK(hipGraphDestroy(graph));
is_graph_capturing_ = false;
}
bool CudaStream::IsGraphCapturing() const { return is_graph_capturing_; }
void CudaStream::LaunchGraph(const CudaGraphExecutable* executable) {
executable->Launch(cuda_stream_);
}
#endif // WITH_ROCM_GRAPHS
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/hardware/node_device_descriptor_manager.h"
#include "oneflow/core/hardware/cuda_device_descriptor.h"
#include "oneflow/core/ep/rocm/cuda_event.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
namespace {
constexpr size_t kDefaultWorkspaceSize = 4 * 1024 * 1024; // 4M
void SetAffinityByDevice(int dev_id) {
auto node_device_desc_mgr = Singleton<hardware::NodeDeviceDescriptorManager>::Get();
if (node_device_desc_mgr == nullptr) { return; }
auto node_device_desc = node_device_desc_mgr->GetLocalNodeDeviceDescriptor();
auto cuda_device = std::dynamic_pointer_cast<const hardware::CudaDeviceDescriptor>(
node_device_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev_id));
if (!cuda_device) { return; }
node_device_desc->Topology()->SetCPUAffinityByPCIBusID(cuda_device->PCIBusID());
node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID());
}
} // namespace
#ifdef WITH_ROCM_GRAPHS
CudaGraphExecutable::CudaGraphExecutable() : graph_exec_(nullptr), dev_(-1) {}
CudaGraphExecutable::~CudaGraphExecutable() { Reset(); }
void CudaGraphExecutable::Update(hipGraph_t graph) {
int dev = -1;
OF_CUDA_CHECK(hipGetDevice(&dev));
if (dev != dev_) { Reset(); }
dev_ = dev;
if (graph_exec_ != nullptr) {
hipGraphExecUpdateResult update_result{};
hipGraphNode_t error_node = nullptr;
OF_CUDA_CHECK(hipGraphExecUpdate(graph_exec_, graph, &error_node, &update_result));
if (update_result == hipGraphExecUpdateSuccess) { return; }
}
Reset();
OF_CUDA_CHECK(hipGraphInstantiate(&graph_exec_, graph, NULL, NULL, 0));
}
void CudaGraphExecutable::Launch(hipStream_t stream) const {
OF_CUDA_CHECK(hipGraphLaunch(graph_exec_, stream));
}
bool CudaGraphExecutable::IsInstantiated() const { return graph_exec_ != nullptr; }
void CudaGraphExecutable::Reset() {
if (graph_exec_ != nullptr) {
CudaCurrentDeviceGuard guard(dev_);
OF_CUDA_CHECK(hipGraphExecDestroy(graph_exec_));
}
}
#endif // WITH_ROCM_GRAPHS
CudaStream::CudaStream(CudaDevice* device)
: device_index_(device->device_index()), device_(device) {
CudaCurrentDeviceGuard guard(device_index_);
// cuda_stream
OF_CUDA_CHECK(hipStreamCreate(&cuda_stream_));
// cublas_handle
OF_CUBLAS_CHECK(hipblasCreate(&cublas_handle_));
OF_CUBLAS_CHECK(hipblasSetStream(cublas_handle_, cuda_stream_));
workspace_size_ = kDefaultWorkspaceSize;
OF_CUDA_CHECK(hipMalloc(&workspace_, workspace_size_));
OF_CUDNN_CHECK(hipdnnCreate(&cudnn_handle_));
OF_CUDNN_CHECK(hipdnnSetStream(cudnn_handle_, cuda_stream_));
}
CudaStream::~CudaStream() {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipStreamSynchronize(cuda_stream_));
OF_CUDNN_CHECK(hipdnnDestroy(cudnn_handle_));
OF_CUBLAS_CHECK(hipblasDestroy(cublas_handle_));
OF_CUDA_CHECK(hipStreamDestroy(cuda_stream_));
OF_CUDA_CHECK(hipFree(workspace_));
}
Maybe<void> CudaStream::OnExecutionContextSetup() {
OF_CUDA_CHECK(hipSetDevice(device_index_));
SetAffinityByDevice(device_index_);
return Maybe<void>::Ok();
}
Maybe<void> CudaStream::OnExecutionContextTeardown() { return Maybe<void>::Ok(); }
DeviceType CudaStream::device_type() const { return DeviceType::kCUDA; }
CudaDevice* CudaStream::device() const { return device_; }
Maybe<void> CudaStream::Sync() {
hipError_t err = hipStreamSynchronize(cuda_stream_);
if (err == hipSuccess) {
return Maybe<void>::Ok();
} else {
return Error::RuntimeError() << hipGetErrorString(err) << " (" << err << ") ";
}
}
void CudaStream::RecordEvent(Event* event) {
auto* cuda_event = static_cast<CudaEvent*>(event); // NOLINT
OF_CUDA_CHECK(hipEventRecord(cuda_event->cuda_event(), cuda_stream_));
}
hipStream_t CudaStream::cuda_stream() const { return cuda_stream_; }
hipblasHandle_t CudaStream::cublas_handle() const { return cublas_handle_; }
void* CudaStream::cublas_workspace() const { return workspace_; }
size_t CudaStream::cublas_workspace_size() const { return workspace_size_; }
hipdnnHandle_t CudaStream::cudnn_handle() const { return cudnn_handle_; }
const hipDeviceProp_t& CudaStream::device_properties() const { return device_->properties(); }
int CudaStream::cuda_arch() const {
return device_->properties().major * 100 + device_->properties().minor * 10;
}
#ifdef WITH_ROCM_GRAPHS
void CudaStream::BeginGraphCapture() {
CHECK(!is_graph_capturing_);
is_graph_capturing_ = true;
OF_CUDA_CHECK(hipStreamBeginCapture(cuda_stream_, hipStreamCaptureModeThreadLocal));
}
void CudaStream::EndGraphCapture(CudaGraphExecutable* executable) {
hipGraph_t graph = nullptr;
OF_CUDA_CHECK(hipStreamEndCapture(cuda_stream_, &graph));
executable->Update(graph);
OF_CUDA_CHECK(hipGraphDestroy(graph));
is_graph_capturing_ = false;
}
bool CudaStream::IsGraphCapturing() const { return is_graph_capturing_; }
void CudaStream::LaunchGraph(const CudaGraphExecutable* executable) {
executable->Launch(cuda_stream_);
}
#endif // WITH_ROCM_GRAPHS
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
#include "oneflow/core/ep/include/stream.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include "oneflow/core/hipdnn/hipdnn.h"
// #if CUDA_VERSION >= 11000
// #define WITH_ROCM_GRAPHS
// #endif // CUDA_VERSION >= 11000
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace ep {
class CudaDevice;
#ifdef WITH_ROCM_GRAPHS
class CudaGraphExecutable {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaGraphExecutable);
CudaGraphExecutable();
~CudaGraphExecutable();
void Update(hipGraph_t graph);
void Launch(hipStream_t stream) const;
bool IsInstantiated() const;
private:
void Reset();
hipGraphExec_t graph_exec_;
int dev_;
};
#endif // WITH_ROCM_GRAPHS
struct CudaLaunchConfig {
dim3 grid_dim;
dim3 block_dim;
size_t shared_mem_size;
CudaLaunchConfig() : grid_dim{}, block_dim{}, shared_mem_size(0) {}
CudaLaunchConfig(unsigned int grid_size, unsigned int block_size, size_t shared_mem_size)
: grid_dim(grid_size), block_dim(block_size), shared_mem_size(shared_mem_size) {}
};
class CudaStream : public Stream {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaStream);
explicit CudaStream(CudaDevice* device);
~CudaStream() override;
static constexpr uint32_t kDefaultBlockSize = 256;
DeviceType device_type() const override;
CudaDevice* device() const override;
Maybe<void> Sync() override;
void RecordEvent(Event* event) override;
Maybe<void> OnExecutionContextSetup() override;
Maybe<void> OnExecutionContextTeardown() override;
hipStream_t cuda_stream() const;
hipblasHandle_t cublas_handle() const;
// #if CUDA_VERSION >= 10010
// cublasLtHandle_t cublas_lt_handle() const;
// #endif
hipdnnHandle_t cudnn_handle() const;
void* cublas_workspace() const;
size_t cublas_workspace_size() const;
const hipDeviceProp_t& device_properties() const;
int cuda_arch() const;
void InitLaunchConfigWithWaves(CudaLaunchConfig* config, size_t elem_cnt, size_t block_size,
size_t max_waves) const {
const uint32_t max_grid_size = max_waves * device_properties().multiProcessorCount
* (device_properties().maxThreadsPerMultiProcessor / block_size);
const uint32_t grid_size =
std::min<uint32_t>(max_grid_size, (elem_cnt + block_size - 1) / block_size);
config->grid_dim = dim3(grid_size);
config->block_dim = dim3(block_size);
config->shared_mem_size = 0;
}
#ifdef __HIPCC__
template<typename... Params, typename... Args>
void LaunchKernel(void (*kernel)(Params...), const CudaLaunchConfig& launch_config,
Args... args) {
kernel<<<launch_config.grid_dim, launch_config.block_dim, launch_config.shared_mem_size,
cuda_stream()>>>(args...);
}
template<typename... Params, typename... Args>
void LaunchKernel(void (*kernel)(Params...), size_t elem_cnt, size_t max_waves, Args... args) {
constexpr uint32_t block_size = kDefaultBlockSize;
CudaLaunchConfig config{};
InitLaunchConfigWithWaves(&config, elem_cnt, block_size, max_waves);
LaunchKernel(kernel, config, args...);
}
template<typename... Params, typename... Args>
void LaunchKernelDefaultWaves(void (*kernel)(Params...), size_t elem_cnt, Args... args) {
const size_t default_waves = 32;
LaunchKernel(kernel, elem_cnt, default_waves, args...);
}
#endif // __HIPCC__
#ifdef WITH_ROCM_GRAPHS
void BeginGraphCapture();
void EndGraphCapture(CudaGraphExecutable* executable);
bool IsGraphCapturing() const;
void LaunchGraph(const CudaGraphExecutable* executable);
#endif // WITH_ROCM_GRAPHS
private:
hipStream_t cuda_stream_{};
hipblasHandle_t cublas_handle_{};
// #if CUDA_VERSION >= 10010
// cublasLtHandle_t cublas_lt_handle_{};
// #endif
hipdnnHandle_t cudnn_handle_{};
int device_index_;
void* workspace_{};
size_t workspace_size_{};
#ifdef WITH_ROCM_GRAPHS
bool is_graph_capturing_{};
#endif // WITH_ROCM_GRAPHS
CudaDevice* device_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
#include "oneflow/core/ep/include/stream.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include "oneflow/core/hipdnn/hipdnn.h"
// #if CUDA_VERSION >= 11000
// #define WITH_ROCM_GRAPHS
// #endif // CUDA_VERSION >= 11000
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace ep {
class CudaDevice;
#ifdef WITH_ROCM_GRAPHS
class CudaGraphExecutable {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaGraphExecutable);
CudaGraphExecutable();
~CudaGraphExecutable();
void Update(hipGraph_t graph);
void Launch(hipStream_t stream) const;
bool IsInstantiated() const;
private:
void Reset();
hipGraphExec_t graph_exec_;
int dev_;
};
#endif // WITH_ROCM_GRAPHS
struct CudaLaunchConfig {
dim3 grid_dim;
dim3 block_dim;
size_t shared_mem_size;
CudaLaunchConfig() : grid_dim{}, block_dim{}, shared_mem_size(0) {}
CudaLaunchConfig(unsigned int grid_size, unsigned int block_size, size_t shared_mem_size)
: grid_dim(grid_size), block_dim(block_size), shared_mem_size(shared_mem_size) {}
};
class CudaStream : public Stream {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaStream);
explicit CudaStream(CudaDevice* device);
~CudaStream() override;
static constexpr uint32_t kDefaultBlockSize = 256;
DeviceType device_type() const override;
CudaDevice* device() const override;
Maybe<void> Sync() override;
void RecordEvent(Event* event) override;
Maybe<void> OnExecutionContextSetup() override;
Maybe<void> OnExecutionContextTeardown() override;
hipStream_t cuda_stream() const;
hipblasHandle_t cublas_handle() const;
// #if CUDA_VERSION >= 10010
// cublasLtHandle_t cublas_lt_handle() const;
// #endif
hipdnnHandle_t cudnn_handle() const;
void* cublas_workspace() const;
size_t cublas_workspace_size() const;
const hipDeviceProp_t& device_properties() const;
int cuda_arch() const;
void InitLaunchConfigWithWaves(CudaLaunchConfig* config, size_t elem_cnt, size_t block_size,
size_t max_waves) const {
const uint32_t max_grid_size = max_waves * device_properties().multiProcessorCount
* (device_properties().maxThreadsPerMultiProcessor / block_size);
const uint32_t grid_size =
std::min<uint32_t>(max_grid_size, (elem_cnt + block_size - 1) / block_size);
config->grid_dim = dim3(grid_size);
config->block_dim = dim3(block_size);
config->shared_mem_size = 0;
}
#ifdef __HIPCC__
template<typename... Params, typename... Args>
void LaunchKernel(void (*kernel)(Params...), const CudaLaunchConfig& launch_config,
Args... args) {
kernel<<<launch_config.grid_dim, launch_config.block_dim, launch_config.shared_mem_size,
cuda_stream()>>>(args...);
}
template<typename... Params, typename... Args>
void LaunchKernel(void (*kernel)(Params...), size_t elem_cnt, size_t max_waves, Args... args) {
constexpr uint32_t block_size = kDefaultBlockSize;
CudaLaunchConfig config{};
InitLaunchConfigWithWaves(&config, elem_cnt, block_size, max_waves);
LaunchKernel(kernel, config, args...);
}
template<typename... Params, typename... Args>
void LaunchKernelDefaultWaves(void (*kernel)(Params...), size_t elem_cnt, Args... args) {
const size_t default_waves = 32;
LaunchKernel(kernel, elem_cnt, default_waves, args...);
}
#endif // __HIPCC__
#ifdef WITH_ROCM_GRAPHS
void BeginGraphCapture();
void EndGraphCapture(CudaGraphExecutable* executable);
bool IsGraphCapturing() const;
void LaunchGraph(const CudaGraphExecutable* executable);
#endif // WITH_ROCM_GRAPHS
private:
hipStream_t cuda_stream_{};
hipblasHandle_t cublas_handle_{};
// #if CUDA_VERSION >= 10010
// cublasLtHandle_t cublas_lt_handle_{};
// #endif
hipdnnHandle_t cudnn_handle_{};
int device_index_;
void* workspace_{};
size_t workspace_size_{};
#ifdef WITH_ROCM_GRAPHS
bool is_graph_capturing_{};
#endif // WITH_ROCM_GRAPHS
CudaDevice* device_;
};
} // namespace ep
} // namespace oneflow
#endif // WITH_ROCM
#endif // ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/add.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/device/cuda_pseudo_bfloat16.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<typename... Args>
struct AddFunctor;
template<typename T>
struct AddFunctor<T> {
__device__ T operator()(T x) const { return x; }
};
template<typename T, typename U, typename... Args>
struct AddFunctor<T, U, Args...> {
__device__ T operator()(T x0, U x1, Args... xs) const {
return x0 + AddFunctor<U, Args...>()(x1, xs...);
}
};
template<typename T, typename... Args>
__global__ void AddGpu(const Args*... srcs, T* dst, size_t count) {
CUDA_1D_KERNEL_LOOP_T(size_t, i, count) { dst[i] = AddFunctor<Args...>()(srcs[i]...); }
}
template<typename T, typename... Args>
void LaunchAddGpu(hipStream_t stream, const Args*... srcs, T* dst, size_t count) {
AddGpu<T, Args...>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(srcs..., dst, count);
}
template<typename T>
void DispatchLaunch(hipStream_t stream, const T* const* srcs, size_t arity, T* dst, size_t count) {
if (arity == 0) {
OF_CUDA_CHECK(hipMemsetAsync(dst, 0, count * sizeof(T), stream));
} else if (arity == 1) {
OF_CUDA_CHECK(hipMemcpyAsync(dst, srcs[0], count * sizeof(T), hipMemcpyDefault, stream));
} else if (arity == 2) {
OF_CUDA_CHECK((cuda::elementwise::Binary<AddFunctor<T, T>, T, T, T>(
AddFunctor<T, T>(), count, dst, srcs[0], srcs[1], stream)));
} else if (arity == 3) {
OF_CUDA_CHECK((cuda::elementwise::Ternary<AddFunctor<T, T, T>, T, T, T, T>(
AddFunctor<T, T, T>(), count, dst, srcs[0], srcs[1], srcs[2], stream)));
} else if (arity == 4) {
LaunchAddGpu<T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], dst, count);
} else if (arity == 5) {
LaunchAddGpu<T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], dst, count);
} else if (arity == 6) {
LaunchAddGpu<T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], srcs[5],
dst, count);
} else if (arity == 7) {
LaunchAddGpu<T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
srcs[5], srcs[6], dst, count);
} else if (arity == 8) {
LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
srcs[5], srcs[6], srcs[7], dst, count);
} else {
DispatchLaunch(stream, srcs + 7, arity - 7, dst, count);
LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
srcs[5], srcs[6], dst, dst, count);
}
}
template<typename T>
class AddImpl : public Add {
public:
OF_DISALLOW_COPY_AND_MOVE(AddImpl);
AddImpl() = default;
~AddImpl() override = default;
using Add::Launch;
void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,
size_t count) override {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
DispatchLaunch(cuda_stream, reinterpret_cast<const T* const*>(srcs), arity,
reinterpret_cast<T*>(dst), count);
}
};
template<typename T>
std::unique_ptr<Add> NewAdd() {
return std::unique_ptr<Add>(new AddImpl<T>());
}
class AddFactoryImpl : public AddFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl);
AddFactoryImpl() = default;
~AddFactoryImpl() override = default;
std::unique_ptr<Add> New(DataType data_type) override {
#define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd<type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<Add>()>> new_add_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_ADD_ENTRY
const auto it = new_add_handle.find(data_type);
if (it != new_add_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, AddFactory, AddFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/add.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/device/cuda_pseudo_bfloat16.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
template<typename... Args>
struct AddFunctor;
template<typename T>
struct AddFunctor<T> {
__device__ T operator()(T x) const { return x; }
};
template<typename T, typename U, typename... Args>
struct AddFunctor<T, U, Args...> {
__device__ T operator()(T x0, U x1, Args... xs) const {
return x0 + AddFunctor<U, Args...>()(x1, xs...);
}
};
template<typename T, typename... Args>
__global__ void AddGpu(const Args*... srcs, T* dst, size_t count) {
CUDA_1D_KERNEL_LOOP_T(size_t, i, count) { dst[i] = AddFunctor<Args...>()(srcs[i]...); }
}
template<typename T, typename... Args>
void LaunchAddGpu(hipStream_t stream, const Args*... srcs, T* dst, size_t count) {
AddGpu<T, Args...>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(srcs..., dst, count);
}
template<typename T>
void DispatchLaunch(hipStream_t stream, const T* const* srcs, size_t arity, T* dst, size_t count) {
if (arity == 0) {
OF_CUDA_CHECK(hipMemsetAsync(dst, 0, count * sizeof(T), stream));
} else if (arity == 1) {
OF_CUDA_CHECK(hipMemcpyAsync(dst, srcs[0], count * sizeof(T), hipMemcpyDefault, stream));
} else if (arity == 2) {
OF_CUDA_CHECK((cuda::elementwise::Binary<AddFunctor<T, T>, T, T, T>(
AddFunctor<T, T>(), count, dst, srcs[0], srcs[1], stream)));
} else if (arity == 3) {
OF_CUDA_CHECK((cuda::elementwise::Ternary<AddFunctor<T, T, T>, T, T, T, T>(
AddFunctor<T, T, T>(), count, dst, srcs[0], srcs[1], srcs[2], stream)));
} else if (arity == 4) {
LaunchAddGpu<T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], dst, count);
} else if (arity == 5) {
LaunchAddGpu<T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], dst, count);
} else if (arity == 6) {
LaunchAddGpu<T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], srcs[5],
dst, count);
} else if (arity == 7) {
LaunchAddGpu<T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
srcs[5], srcs[6], dst, count);
} else if (arity == 8) {
LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
srcs[5], srcs[6], srcs[7], dst, count);
} else {
DispatchLaunch(stream, srcs + 7, arity - 7, dst, count);
LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
srcs[5], srcs[6], dst, dst, count);
}
}
template<typename T>
class AddImpl : public Add {
public:
OF_DISALLOW_COPY_AND_MOVE(AddImpl);
AddImpl() = default;
~AddImpl() override = default;
using Add::Launch;
void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,
size_t count) override {
hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
DispatchLaunch(cuda_stream, reinterpret_cast<const T* const*>(srcs), arity,
reinterpret_cast<T*>(dst), count);
}
};
template<typename T>
std::unique_ptr<Add> NewAdd() {
return std::unique_ptr<Add>(new AddImpl<T>());
}
class AddFactoryImpl : public AddFactory {
public:
OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl);
AddFactoryImpl() = default;
~AddFactoryImpl() override = default;
std::unique_ptr<Add> New(DataType data_type) override {
#define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd<type_cpp>},
static const std::map<DataType, std::function<std::unique_ptr<Add>()>> new_add_handle{
OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_ADD_ENTRY
const auto it = new_add_handle.find(data_type);
if (it != new_add_handle.end()) {
return it->second();
} else {
return nullptr;
}
}
};
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, AddFactory, AddFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment