Commit 8f7de847 authored by yuguo960516yuguo

dtk

parent f262efc9
Pipeline #248 failed with stages in 0 seconds
# OneFlow
-OneFlow is a deep learning framework designed to be **user-friendly, scalable and efficient**. With OneFlow, it is easy to:
-- program a model with **PyTorch-like API**
-- scale a model to n-dimensional-parallel/distributed execution with the **Global View API**
-- accelerate/deploy a model with the **Static Graph Compiler**.
+**OneFlow is a performance-centered and open-source deep learning framework.**
[![Simple CI](https://github.com/Oneflow-Inc/oneflow/actions/workflows/simple.yml/badge.svg)](https://github.com/Oneflow-Inc/oneflow/actions/workflows/simple.yml)
[![Nightly Docker Image](https://github.com/Oneflow-Inc/docker-images/actions/workflows/oneflow-nightly.yml/badge.svg)](https://github.com/Oneflow-Inc/docker-images/actions/workflows/oneflow-nightly.yml)
@@ -12,8 +9,10 @@ OneFlow is a deep learning framework designed to be **user-friendly, scalable an
## Latest News
-- Version 0.8.0 is out!
-- [Full changelog](https://github.com/Oneflow-Inc/oneflow/releases/tag/v0.8.0)
+- Version 0.7.0 is out!
+- Introducing global tensor
+- Semi-auto parallelization has landed
+- [Full changelog](https://github.com/Oneflow-Inc/oneflow/releases/tag/v0.7.0)
## Publication
@@ -36,7 +35,7 @@ OneFlow is a deep learning framework designed to be **user-friendly, scalable an
### System Requirements
- Linux. As for now, there is no pre-built release for macOS, Windows.
-- Python 3.7, 3.8, 3.9, 3.10
+- Python 3.6, 3.7, 3.8, 3.9, 3.10
- (**Highly recommended**) Upgrade pip
```
@@ -54,7 +53,7 @@ OneFlow is a deep learning framework designed to be **user-friendly, scalable an
- To install latest stable release of OneFlow with CUDA support:
```bash
-python3 -m pip install oneflow
+python3 -m pip install -f https://release.oneflow.info oneflow==0.7.0+cu102
```
- To install nightly release of OneFlow with CUDA support:
@@ -67,7 +66,7 @@ OneFlow is a deep learning framework designed to be **user-friendly, scalable an
- Stable
```bash
-python3 -m pip install --find-links https://release.oneflow.info oneflow==0.8.0+[PLATFORM]
+python3 -m pip install --find-links https://release.oneflow.info oneflow==0.7.0+[PLATFORM]
```
- Nightly
```
...
# Monkey patch auditwheel's policies so that the ROCm/DTK runtime libraries are not
# bundled into the pypi wheels.
import sys
from auditwheel.main import main
from auditwheel.policy import _POLICIES as POLICIES

# These libraries are provided by the ROCm/DTK installation at runtime; whitelist them so
# auditwheel treats them as system libraries and does not copy them into the repaired wheel.
for p in POLICIES:
    p['lib_whitelist'].append('librccl.so.1')
    p['lib_whitelist'].append('libhipblas.so.0')
    p['lib_whitelist'].append('libhiprand.so.1')
    p['lib_whitelist'].append('librocrand.so.1')
    p['lib_whitelist'].append('libMIOpen.so.1')
    p['lib_whitelist'].append('libgalaxyhip.so.4')
    p['lib_whitelist'].append('librocm_smi64.so.2')
    p['lib_whitelist'].append('librocsolver.so.0')
    p['lib_whitelist'].append('librocblas.so.0')

if __name__ == "__main__":
    sys.exit(main())
# Monkey patch auditwheel's policies so that the ROCm/DTK runtime libraries are not
# bundled into the pypi wheels.
import sys
from auditwheel.main import main
from auditwheel.policy import _POLICIES as POLICIES

# These libraries are provided by the ROCm/DTK installation at runtime; whitelist them so
# auditwheel treats them as system libraries and does not copy them into the repaired wheel.
for p in POLICIES:
    p['lib_whitelist'].append('librccl.so.1')
    p['lib_whitelist'].append('libhipblas.so.0')
    p['lib_whitelist'].append('libhiprand.so.1')
    p['lib_whitelist'].append('librocrand.so.1')
    p['lib_whitelist'].append('libMIOpen.so.1')
    p['lib_whitelist'].append('libgalaxyhip.so.5')
    p['lib_whitelist'].append('librocm_smi64.so.2')
    p['lib_whitelist'].append('librocsolver.so.0')
    p['lib_whitelist'].append('librocblas.so.0')

if __name__ == "__main__":
    sys.exit(main())
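For reference, a script like the two above is typically used as a drop-in replacement for the stock `auditwheel` entry point when repairing the built wheels. A minimal invocation sketch follows; the script name `audit.py` and the wheel paths are illustrative assumptions, not file names taken from this commit:

```bash
# Hypothetical usage: run the patched auditwheel so that the whitelisted ROCm/DTK
# libraries are treated as system-provided and are not copied into the repaired wheel.
python3 audit.py repair dist/oneflow-*.whl -w wheelhouse/
```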
@@ -328,6 +328,17 @@ if(BUILD_PYTHON OR BUILD_CPP_API)
endif()
endif()
if (BUILD_ROCM)
  # The AMD compiler fails to compile these three files with '-O1/2/3'.
  # The value of the `COMPILE_OPTIONS` source property is appended after CMAKE_<LANG>_FLAGS_<CONFIG>,
  # so the '-O0' set here overrides '-O1/2/3'.
  set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel.hip.cpp
                              ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel.hip.cpp
                              ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel.hip.cpp
                              # ${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math.hip.cpp
                              PROPERTIES COMPILE_OPTIONS "-O0")
endif()
if(BUILD_PYTHON)
# py ext lib
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/cached_key_value_store.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
namespace oneflow {
namespace embedding {
namespace {
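// PostStoreGetKernel scatters the values fetched from the backing store into the caller's output
// buffer at the positions that missed in the cache, and rewrites the store-miss indices so that
// they refer back to the original query positions.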
template<typename Key, typename Elem>
__global__ void PostStoreGetKernel(uint32_t num_cache_missing, uint32_t num_store_missing,
                                   uint32_t num_elems_per_value,
                                   const uint32_t* cache_missing_indices,
                                   const uint32_t* store_missing_indices, const Elem* store_values,
                                   Elem* values, uint32_t* missing_indices) {
  const uint32_t num_cache_missing_elem = num_cache_missing * num_elems_per_value;
  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_cache_missing_elem) {
    const uint32_t value_index = i / num_elems_per_value;
    const uint32_t elem_index = i - value_index * num_elems_per_value;
    values[cache_missing_indices[value_index] * num_elems_per_value + elem_index] = store_values[i];
  }
  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_store_missing) {
    missing_indices[i] = cache_missing_indices[store_missing_indices[i]];
  }
}
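// CacheKeyValueStoreImpl layers a GPU-resident Cache in front of another KeyValueStore: lookups
// are served from the cache first and only cache misses fall through to the wrapped store, while
// Put/Snapshot operations keep the two in sync.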
template<typename Key, typename Elem>
class CacheKeyValueStoreImpl : public KeyValueStore {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl);
  CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache)
      : store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    CHECK_EQ(store_->KeySize(), cache_->KeySize());
    CHECK_EQ(store_->ValueSize(), cache_->ValueSize());
    OF_CUDA_CHECK(hipMalloc(&num_buffer_, sizeof(uint32_t)));
    OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&host_num_buffer_), sizeof(uint32_t)));
    num_elems_per_value_ = store_->ValueSize() / sizeof(Elem);
  }
  ~CacheKeyValueStoreImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    OF_CUDA_CHECK(hipFree(num_buffer_));
    OF_CUDA_CHECK(hipHostFree(host_num_buffer_));
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipFree(keys_buffer_));
      OF_CUDA_CHECK(hipFree(values_buffer_));
      OF_CUDA_CHECK(hipFree(indices_buffer0_));
      OF_CUDA_CHECK(hipFree(indices_buffer1_));
    }
    cache_.reset();
    store_.reset();
  }
  uint32_t KeySize() const override { return store_->KeySize(); }
  uint32_t ValueSize() const override { return store_->ValueSize(); }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); }
    if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); }
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipFree(keys_buffer_));
      OF_CUDA_CHECK(hipFree(values_buffer_));
      OF_CUDA_CHECK(hipFree(indices_buffer0_));
      OF_CUDA_CHECK(hipFree(indices_buffer1_));
    }
    OF_CUDA_CHECK(hipMalloc(&keys_buffer_, query_length * store_->KeySize()));
    OF_CUDA_CHECK(hipMalloc(&values_buffer_, query_length * store_->ValueSize()));
    OF_CUDA_CHECK(hipMalloc(&indices_buffer0_, query_length * sizeof(uint32_t)));
    OF_CUDA_CHECK(hipMalloc(&indices_buffer1_, query_length * sizeof(uint32_t)));
    max_query_length_ = query_length;
  }
  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
           uint32_t* n_missing, uint32_t* missing_indices) override;
  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
           uint8_t* mask) override;
  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
  void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
                          const void* update, const float* lr, float scale) override;
  bool IsFusionSupported() override {
    return cache_->Policy() == CacheOptions::Policy::kFull
           && cache_->ValueType() == DataType::kFloat;
  }
  bool SnapshotExists(const std::string& name) override;
  void LoadSnapshot(const std::string& name) override;
  void SaveSnapshot(const std::string& name) override;
  void LoadSnapshot(const std::string& name,
                    const std::function<void(KVIterator* iter)>& Hook) override;

 private:
  void SyncCacheToStore();
  std::unique_ptr<KeyValueStore> store_;
  std::unique_ptr<Cache> cache_;
  uint32_t* num_buffer_{};
  uint32_t* host_num_buffer_{};
  Key* keys_buffer_{};
  Elem* values_buffer_{};
  uint32_t* indices_buffer0_{};
  uint32_t* indices_buffer1_{};
  int device_index_{};
  uint32_t max_query_length_;
  uint32_t num_elems_per_value_{};
  std::recursive_mutex mutex_;
  bool synced_;
};
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                            void* values, uint32_t* n_missing,
                                            uint32_t* missing_indices) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  if (cache_->Policy() == CacheOptions::Policy::kFull) {
    cache_->Get(stream, num_keys, keys, values, n_missing, keys_buffer_, missing_indices);
    return;
  } else {
    cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_);
  }
  OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  const uint32_t num_cache_missing = *host_num_buffer_;
  if (num_cache_missing == 0) {
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    return;
  }
  store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_);
  OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  const uint32_t num_store_missing = *host_num_buffer_;
  RUN_CUDA_KERNEL((PostStoreGetKernel<Key, Elem>), stream, num_cache_missing * num_elems_per_value_,
                  num_cache_missing, num_store_missing, num_elems_per_value_, indices_buffer0_,
                  indices_buffer1_, values_buffer_, static_cast<Elem*>(values), missing_indices);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                            void* values, uint8_t* mask) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  if (cache_->Policy() == CacheOptions::Policy::kFull) {
    cache_->Get(stream, num_keys, keys, values, mask);
    return;
  } else {
    UNIMPLEMENTED();
  }
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                            const void* values) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  synced_ = false;
  auto cuda_stream = stream->As<ep::CudaStream>();
  cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_);
  if (cache_->Policy() == CacheOptions::Policy::kFull) { return; }
  OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, uint32_t num_keys,
                                                            const void* keys, const void* values,
                                                            const void* update, const float* lr,
                                                            float scale) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) {
    UNIMPLEMENTED();
  }
  synced_ = false;
  cache_->FusedHalfUpdatePut(stream, num_keys, keys, values, update, lr, scale, num_buffer_,
                             keys_buffer_, values_buffer_);
}
template<typename Key, typename Elem>
bool CacheKeyValueStoreImpl<Key, Elem>::SnapshotExists(const std::string& name) {
  return store_->SnapshotExists(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(const std::string& name) {
  LoadSnapshot(name, nullptr);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(
    const std::string& name, const std::function<void(KVIterator* iter)>& Hook) {
  CudaCurrentDeviceGuard guard(device_index_);
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  CHECK_GT(max_query_length_, 0);
  cache_->Clear();
  auto device =
      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
  CHECK(device);
  auto* stream = device->CreateStream();
  store_->LoadSnapshot(name, [&](KVIterator* iter) {
    if (cache_->Policy() == CacheOptions::Policy::kFull) {
      auto* cuda_stream = stream->As<ep::CudaStream>();
      while (true) {
        iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_);
        OF_CUDA_CHECK(hipDeviceSynchronize());
        OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
                                     hipMemcpyDefault, cuda_stream->cuda_stream()));
        CHECK_JUST(stream->Sync());
        if (*host_num_buffer_ == 0) { return; }
        cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr,
                    nullptr);
        OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
                                     hipMemcpyDefault, cuda_stream->cuda_stream()));
        CHECK_JUST(stream->Sync());
        CHECK_EQ(*host_num_buffer_, 0);
      }
    }
    if (Hook) {
      iter->Reset();
      Hook(iter);
    }
  });
  device->DestroyStream(stream);
  store_->LoadSnapshot(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SaveSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  SyncCacheToStore();
  store_->SaveSnapshot(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() {
  if (synced_) { return; }
  CudaCurrentDeviceGuard guard(device_index_);
  auto device =
      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
  CHECK(device);
  auto* stream = device->CreateStream();
  auto* cuda_stream = stream->As<ep::CudaStream>();
  const uint64_t dump_capacity = cache_->DumpCapacity();
  CHECK_GT(max_query_length_, 0);
  for (uint64_t start_key_index = 0; start_key_index < dump_capacity;
       start_key_index += max_query_length_) {
    cache_->Dump(stream, start_key_index,
                 std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_,
                 keys_buffer_, values_buffer_);
    OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
                                 hipMemcpyDefault, cuda_stream->cuda_stream()));
    CHECK_JUST(stream->Sync());
    if (*host_num_buffer_ == 0) { continue; }
    store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
    CHECK_JUST(stream->Sync());
  }
  device->DestroyStream(stream);
  synced_ = true;
}
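// The element type Elem is chosen as the widest unsigned type whose size divides the value size,
// so that each value row is copied with as few (and as wide) memory transactions as possible.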
template<typename Key>
std::unique_ptr<KeyValueStore> DispatchElemType(std::unique_ptr<KeyValueStore>&& store,
                                                std::unique_ptr<Cache>&& cache) {
  const uint32_t value_size = store->ValueSize();
  if (value_size % sizeof(uint4) == 0) {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint4>(std::move(store), std::move(cache)));
  } else if (value_size % sizeof(uint64_t) == 0) {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint64_t>(std::move(store), std::move(cache)));
  } else if (value_size % sizeof(uint32_t) == 0) {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint32_t>(std::move(store), std::move(cache)));
  } else if (value_size % sizeof(uint16_t) == 0) {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint16_t>(std::move(store), std::move(cache)));
  } else {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint8_t>(std::move(store), std::move(cache)));
  }
}
std::unique_ptr<KeyValueStore> DispatchKeyType(std::unique_ptr<KeyValueStore>&& store,
                                               std::unique_ptr<Cache>&& cache) {
  const uint32_t key_size = store->KeySize();
  if (key_size == 4) {
    return DispatchElemType<uint32_t>(std::move(store), std::move(cache));
  } else if (key_size == 8) {
    return DispatchElemType<uint64_t>(std::move(store), std::move(cache));
  } else {
    UNIMPLEMENTED();
    return nullptr;
  }
}
} // namespace
std::unique_ptr<KeyValueStore> NewCachedKeyValueStore(std::unique_ptr<KeyValueStore>&& store,
                                                      std::unique_ptr<Cache>&& cache) {
  return DispatchKeyType(std::move(store), std::move(cache));
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/full_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include "oneflow/core/hip/atomic.hip.h"
namespace oneflow {
namespace embedding {
using Key32 = unsigned int;
using Key64 = unsigned long long int;
using Key128 = ulonglong2;
namespace {
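// The open-addressing hash table stores (key | 0x1) in the key slot and packs (ordinal + 1)
// together with the key's lowest bit in the index slot. Because the stored key is never zero,
// zero can serve as the "empty slot" sentinel, and the low bit stashed in the index slot
// disambiguates the original key on lookup and dump.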
template<typename Key, typename Index>
__device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Index* table_size,
                               Key key, Index* out) {
  Key key_hi = (key | 0x1);
  Key key_lo = (key & 0x1);
  Index index_plus_one = 0;
  Key old_entry_key = cuda::atomic::CAS(entry_key, static_cast<Key>(0), key_hi);
  while (index_plus_one == 0) {
    if (old_entry_key == static_cast<Key>(0)) {
      Index index = cuda::atomic::Add(table_size, static_cast<Index>(1));
      index_plus_one = index + 1;
      *entry_index = ((index_plus_one << 1U) | key_lo);
      *out = index_plus_one;
      return true;
    } else if (old_entry_key == key_hi) {
      const Index entry_index_val = *entry_index;
      if (entry_index_val == 0) {
        // do nothing
      } else if ((entry_index_val & 0x1) == key_lo) {
        *out = (entry_index_val >> 1U);
        return true;
      } else {
        return false;
      }
    } else {
      return false;
    }
  }
  return false;
}
template<typename Key, typename Index>
__device__ bool GetOrInsertOne(const size_t capacity, Key* table_keys, Index* table_indices,
                               Index* table_size, Key key, size_t hash, Index* out) {
  const size_t start_idx = hash % capacity;
  for (size_t count = 0; count < capacity; ++count) {
    const size_t idx = (start_idx + count) % capacity;
    Key* entry_key = table_keys + idx;
    Index* entry_index = table_indices + idx;
    if (TryGetOrInsert<Key, Index>(entry_key, entry_index, table_size, key, out)) { return true; }
  }
  return false;
}
template<typename Key, typename Index>
__device__ bool GetOne(const size_t capacity, Key* table_keys, Index* table_indices, Key key,
                       size_t hash, Index* out) {
  const size_t start_idx = hash % capacity;
  for (size_t count = 0; count < capacity; ++count) {
    const size_t idx = (start_idx + count) % capacity;
    Key entry_key = table_keys[idx];
    Key entry_index = table_indices[idx];
    Key key_hi = (key | 0x1);
    Key key_lo = (key & 0x1);
    if (entry_key == 0) { break; }
    if (entry_key == key_hi) {
      if ((entry_index & 0x1) == key_lo) {
        *out = (entry_index >> 1U);
        return true;
      }
    }
  }
  *out = 0;
  return false;
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
                                    Index* table_size, uint32_t num_keys, const Key* keys,
                                    Index* context) {
  CUDA_1D_KERNEL_LOOP(i, num_keys) {
    Key key = keys[i];
    uint64_t hash = FullCacheHash()(key);
    bool success = GetOrInsertOne<Key, Index>(capacity, table_keys, table_indices, table_size, key,
                                              hash, context + i);
    assert(success);
  }
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeLookupKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
                                          uint32_t num_keys, const Key* keys, Index* context) {
  CUDA_1D_KERNEL_LOOP(i, num_keys) {
    Key key = keys[i];
    uint64_t hash = FullCacheHash()(key);
    GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, context + i);
  }
}
template<typename Key, typename Index>
__global__ void OrdinalEncodeDumpKernel(const Key* table_keys, const Index* table_indices,
                                        uint64_t start_key_index, uint64_t end_key_index,
                                        uint32_t* n_dumped, Key* keys, Index* context) {
  CUDA_1D_KERNEL_LOOP(i, (end_key_index - start_key_index)) {
    Key entry_key = table_keys[i + start_key_index];
    Index entry_index = table_indices[i + start_key_index];
    if (entry_index != 0) {
      uint32_t index = cuda::atomic::Add(n_dumped, static_cast<uint32_t>(1));
      keys[index] = ((entry_key ^ 0x1) | (entry_index & 0x1));
      context[index] = (entry_index >> 1U);
    }
  }
}
template<typename Key, typename Elem, typename Index, bool return_value>
__global__ void LookupKernel(uint32_t value_length, const Elem* cache_values,
                             uint32_t values_elem_cnt, const Key* keys, const Index* context,
                             Elem* values, uint32_t* n_missing, Key* missing_keys,
                             uint32_t* missing_indices) {
  CUDA_1D_KERNEL_LOOP(i, values_elem_cnt) {
    const uint64_t key_id = i / value_length;
    const uint64_t ctx = context[key_id];
    const uint64_t row_id = ctx - 1;
    const uint64_t col_id = i - key_id * value_length;
    if (ctx == 0) {
      const Key missing_key = keys[key_id];
      if (col_id == 0) {
        const uint32_t old_n_missing = cuda::atomic::Add(n_missing, static_cast<uint32_t>(1));
        missing_keys[old_n_missing] = missing_key;
        missing_indices[old_n_missing] = key_id;
      }
      continue;
    }
    if (return_value) { values[i] = cache_values[row_id * value_length + col_id]; }
  }
}
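// EncodeLookupKernel fuses the hash-table lookup with the value gather: each warp handles a batch
// of up to 32 keys, resolves their row ordinals, accumulates cache misses in shared memory before
// appending them to the global miss list, and then cooperatively copies the hit rows from
// cache_values into the output buffer.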
template<typename Key, typename Elem, typename Index, uint32_t block_size> template<typename Key, typename Elem, typename Index, uint32_t block_size>
__global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_values, __global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_values,
uint32_t values_elem_cnt, const Key* keys, const Index* context, uint32_t values_elem_cnt, const Key* keys, const Index* context,
Elem* values, uint32_t* n_missing, Key* missing_keys, Elem* values, uint32_t* n_missing, Key* missing_keys,
uint32_t* missing_indices, const size_t capacity, uint32_t* missing_indices, const size_t capacity,
Key* table_keys, Index* table_indices) { Key* table_keys, Index* table_indices) {
constexpr uint32_t warp_size = 32; constexpr uint32_t warp_size = 32;
constexpr uint32_t n_warp_per_block = block_size / warp_size; constexpr uint32_t n_warp_per_block = block_size / warp_size;
const uint32_t warp_id = threadIdx.x / warp_size; const uint32_t warp_id = threadIdx.x / warp_size;
const uint32_t lane_id = threadIdx.x % warp_size; const uint32_t lane_id = threadIdx.x % warp_size;
const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id; const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
const uint32_t global_n_warp = gridDim.x * n_warp_per_block; const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
const uint32_t n_keys = values_elem_cnt / value_length; const uint32_t n_keys = values_elem_cnt / value_length;
__shared__ Key batch_keys[n_warp_per_block][warp_size]; __shared__ Key batch_keys[n_warp_per_block][warp_size];
__shared__ Index batch_row_ids[n_warp_per_block][warp_size]; __shared__ Index batch_row_ids[n_warp_per_block][warp_size];
__shared__ Key batch_missing_keys[n_warp_per_block][warp_size]; __shared__ Key batch_missing_keys[n_warp_per_block][warp_size];
__shared__ uint32_t batch_missing_indices[n_warp_per_block][warp_size]; __shared__ uint32_t batch_missing_indices[n_warp_per_block][warp_size];
__shared__ uint32_t batch_n_missing[n_warp_per_block]; __shared__ uint32_t batch_n_missing[n_warp_per_block];
for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys; for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
batch_start += global_n_warp * warp_size) { batch_start += global_n_warp * warp_size) {
const uint32_t batch_n_key = min(n_keys - batch_start, warp_size); const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
if (lane_id == 0) { batch_n_missing[warp_id] = 0; } if (lane_id == 0) { batch_n_missing[warp_id] = 0; }
__syncthreads(); __syncthreads();
const uint32_t key_offset = batch_start + lane_id; const uint32_t key_offset = batch_start + lane_id;
if (key_offset < n_keys) { if (key_offset < n_keys) {
const Key key = keys[batch_start + lane_id]; const Key key = keys[batch_start + lane_id];
const uint64_t hash = FullCacheHash()(key); const uint64_t hash = FullCacheHash()(key);
Index row; Index row;
GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row); GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
batch_row_ids[warp_id][lane_id] = row; batch_row_ids[warp_id][lane_id] = row;
if (row == 0) { if (row == 0) {
const uint32_t batch_missing_idx = atomicAdd(batch_n_missing + warp_id, 1); const uint32_t batch_missing_idx = atomicAdd(batch_n_missing + warp_id, 1);
batch_missing_keys[warp_id][batch_missing_idx] = key; batch_missing_keys[warp_id][batch_missing_idx] = key;
batch_missing_indices[warp_id][batch_missing_idx] = key_offset; batch_missing_indices[warp_id][batch_missing_idx] = key_offset;
} }
} }
__syncthreads(); __syncthreads();
const uint32_t batch_n_missing_t = batch_n_missing[warp_id]; const uint32_t batch_n_missing_t = batch_n_missing[warp_id];
if (lane_id == 0) { if (lane_id == 0) {
const uint32_t old_n_missing = const uint32_t old_n_missing =
cuda::atomic::Add(n_missing, static_cast<uint32_t>(batch_n_missing_t)); cuda::atomic::Add(n_missing, static_cast<uint32_t>(batch_n_missing_t));
batch_n_missing[warp_id] = old_n_missing; batch_n_missing[warp_id] = old_n_missing;
} }
__syncthreads(); __syncthreads();
if (lane_id < batch_n_missing_t) { if (lane_id < batch_n_missing_t) {
missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id]; missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id];
missing_indices[batch_n_missing[warp_id] + lane_id] = batch_missing_indices[warp_id][lane_id]; missing_indices[batch_n_missing[warp_id] + lane_id] = batch_missing_indices[warp_id][lane_id];
} }
for (int i = 0; i < batch_n_key; ++i) { for (int i = 0; i < batch_n_key; ++i) {
const Key key = batch_keys[warp_id][i]; const Key key = batch_keys[warp_id][i];
const Index row = batch_row_ids[warp_id][i]; const Index row = batch_row_ids[warp_id][i];
if (row == 0) { continue; } if (row == 0) { continue; }
for (int col = lane_id; col < value_length; col += warp_size) { for (int col = lane_id; col < value_length; col += warp_size) {
values[(batch_start + i) * value_length + col] = values[(batch_start + i) * value_length + col] =
cache_values[(row - 1) * value_length + col]; cache_values[(row - 1) * value_length + col];
} }
} }
__syncthreads(); __syncthreads();
} }
} }
template<typename T, size_t pack_size> template<typename T, size_t pack_size>
struct alignas(sizeof(T) * pack_size) Pack { struct alignas(sizeof(T) * pack_size) Pack {
T elem[pack_size]; T elem[pack_size];
}; };
template<typename Key, typename Elem, typename Index, uint32_t block_size, uint32_t pack_size> template<typename Key, typename Elem, typename Index, uint32_t block_size, uint32_t pack_size>
__global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __restrict__ cache_values, __global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __restrict__ cache_values,
uint32_t values_elem_cnt, const Key* __restrict__ keys, uint32_t values_elem_cnt, const Key* __restrict__ keys,
const Index* __restrict__ context, Elem* __restrict__ values, const Index* __restrict__ context, Elem* __restrict__ values,
uint8_t* __restrict__ mask, const size_t capacity, uint8_t* __restrict__ mask, const size_t capacity,
Key* __restrict__ table_keys, Key* __restrict__ table_keys,
Index* __restrict__ table_indices) { Index* __restrict__ table_indices) {
const uint32_t packed_cols = value_length / pack_size; const uint32_t packed_cols = value_length / pack_size;
auto* packed_values = reinterpret_cast<Pack<Elem, pack_size>*>(values); auto* packed_values = reinterpret_cast<Pack<Elem, pack_size>*>(values);
const auto* packed_cache_values = reinterpret_cast<const Pack<Elem, pack_size>*>(cache_values); const auto* packed_cache_values = reinterpret_cast<const Pack<Elem, pack_size>*>(cache_values);
constexpr uint32_t warp_size = 32; constexpr uint32_t warp_size = 32;
constexpr uint32_t n_warp_per_block = block_size / warp_size; constexpr uint32_t n_warp_per_block = block_size / warp_size;
const uint32_t warp_id = threadIdx.x / warp_size; const uint32_t warp_id = threadIdx.x / warp_size;
const uint32_t lane_id = threadIdx.x % warp_size; const uint32_t lane_id = threadIdx.x % warp_size;
const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id; const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
const uint32_t global_n_warp = gridDim.x * n_warp_per_block; const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
const uint32_t n_keys = values_elem_cnt / value_length; const uint32_t n_keys = values_elem_cnt / value_length;
__shared__ Key batch_keys[n_warp_per_block][warp_size]; __shared__ Key batch_keys[n_warp_per_block][warp_size];
__shared__ Index batch_row_ids[n_warp_per_block][warp_size]; __shared__ Index batch_row_ids[n_warp_per_block][warp_size];
for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys; for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
batch_start += global_n_warp * warp_size) { batch_start += global_n_warp * warp_size) {
const uint32_t batch_n_key = min(n_keys - batch_start, warp_size); const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
const uint32_t key_offset = batch_start + lane_id; const uint32_t key_offset = batch_start + lane_id;
if (key_offset < n_keys) { if (key_offset < n_keys) {
const Key key = keys[batch_start + lane_id]; const Key key = keys[batch_start + lane_id];
const uint64_t hash = FullCacheHash()(key); const uint64_t hash = FullCacheHash()(key);
Index row; Index row;
GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row); GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
batch_row_ids[warp_id][lane_id] = row; batch_row_ids[warp_id][lane_id] = row;
mask[key_offset] = row > 0; mask[key_offset] = row > 0;
} }
__syncthreads(); __syncthreads();
for (int i = 0; i < batch_n_key; ++i) { for (int i = 0; i < batch_n_key; ++i) {
const Key key = batch_keys[warp_id][i]; const Key key = batch_keys[warp_id][i];
const Index row = batch_row_ids[warp_id][i]; const Index row = batch_row_ids[warp_id][i];
if (row == 0) { continue; } if (row == 0) { continue; }
#pragma unroll 4 #pragma unroll 4
for (int col = lane_id; col < packed_cols; col += warp_size) { for (int col = lane_id; col < packed_cols; col += warp_size) {
packed_values[(batch_start + i) * packed_cols + col] = packed_values[(batch_start + i) * packed_cols + col] =
packed_cache_values[(row - 1) * packed_cols + col]; packed_cache_values[(row - 1) * packed_cols + col];
} }
} }
__syncthreads(); __syncthreads();
} }
} }
template<typename Elem, typename Index, size_t pack_size> template<typename Elem, typename Index, size_t pack_size>
__global__ void UpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt, __global__ void UpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
const Index* context, const Elem* values) { const Index* context, const Elem* values) {
const int packed_values_elem_cnt = values_elem_cnt / pack_size; const int packed_values_elem_cnt = values_elem_cnt / pack_size;
const uint32_t packed_elem_cnt = value_length / pack_size; const uint32_t packed_elem_cnt = value_length / pack_size;
auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values); auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values); auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) { CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {
const uint64_t key_id = i / packed_elem_cnt; const uint64_t key_id = i / packed_elem_cnt;
const uint64_t ctx = context[key_id]; const uint64_t ctx = context[key_id];
if (ctx == 0) { continue; } if (ctx == 0) { continue; }
const uint64_t row_id = ctx - 1; const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * packed_elem_cnt; const uint64_t col_id = i - key_id * packed_elem_cnt;
packed_cache_values[row_id * packed_elem_cnt + col_id] = packed_values[i]; packed_cache_values[row_id * packed_elem_cnt + col_id] = packed_values[i];
} }
} }
template<typename Elem, typename Index, size_t pack_size> template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<std::is_same<Elem, float>::value, void>::type __global__ typename std::enable_if<std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* __restrict__ cache_values, FusedHalfUpdateKernel(uint32_t value_length, Elem* __restrict__ cache_values,
uint32_t values_elem_cnt, const Index* __restrict__ context, uint32_t values_elem_cnt, const Index* __restrict__ context,
const Elem* __restrict__ values, const half* __restrict__ update, const Elem* __restrict__ values, const half* __restrict__ update,
const float* __restrict__ lr, float scale) { const float* __restrict__ lr, float scale) {
const int packed_values_elem_cnt = values_elem_cnt / pack_size; const int packed_values_elem_cnt = values_elem_cnt / pack_size;
const uint32_t packed_elem_cnt = value_length / pack_size; const uint32_t packed_elem_cnt = value_length / pack_size;
auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values); auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values); auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
auto* packed_update = reinterpret_cast<const Pack<half, pack_size>*>(update); auto* packed_update = reinterpret_cast<const Pack<half, pack_size>*>(update);
const float alpha = -*lr * scale; const float alpha = -*lr * scale;
CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) { CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {
const uint64_t key_id = i / packed_elem_cnt; const uint64_t key_id = i / packed_elem_cnt;
const uint64_t ctx = context[key_id]; const uint64_t ctx = context[key_id];
if (ctx == 0) { continue; } if (ctx == 0) { continue; }
const uint64_t row_id = ctx - 1; const uint64_t row_id = ctx - 1;
const uint64_t col_id = i - key_id * packed_elem_cnt; const uint64_t col_id = i - key_id * packed_elem_cnt;
Pack<Elem, pack_size> m = packed_values[i]; Pack<Elem, pack_size> m = packed_values[i];
Pack<half, pack_size> u = packed_update[i]; Pack<half, pack_size> u = packed_update[i];
for (size_t j = 0; j < pack_size; ++j) { m.elem[j] += static_cast<Elem>(u.elem[j]) * alpha; } for (size_t j = 0; j < pack_size; ++j) { m.elem[j] += static_cast<Elem>(u.elem[j]) * alpha; }
packed_cache_values[row_id * packed_elem_cnt + col_id] = m; packed_cache_values[row_id * packed_elem_cnt + col_id] = m;
} }
} }
template<typename Elem, typename Index, size_t pack_size> template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::type __global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt, FusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
const Index* context, const Elem* values, const half* update, const float* lr, const Index* context, const Elem* values, const half* update, const float* lr,
float scale) { float scale) {
asm volatile("s_trap 0;"); asm volatile("s_trap 0;");
} }
template<typename Key, typename Elem, typename Index>
__global__ void DumpValueKernel(uint32_t value_length, const uint32_t* n_dumped,
                                const Index* context, const Elem* cache_values, Elem* values) {
  CUDA_1D_KERNEL_LOOP(i, *n_dumped * value_length) {
    const uint64_t key_id = i / value_length;
    const uint64_t ctx = context[key_id];
    const uint64_t row_id = ctx - 1;
    const uint64_t col_id = i - key_id * value_length;
    values[i] = cache_values[row_id * value_length + col_id];
  }
}
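
// DumpValueKernel gathers cached embedding rows back into a dense buffer: element i belongs to
// dumped key i / value_length and column i % value_length, and context[key_id] holds the cache row
// index plus one (0 is reserved for "empty"). For example, with value_length = 128 and
// *n_dumped = 1000 the grid-stride loop covers 128000 elements, and key 3, column 5 reads
// cache_values[(context[3] - 1) * 128 + 5].
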
template<typename Key, typename Index>
class OrdinalEncoder {
 public:
  OF_DISALLOW_COPY_AND_MOVE(OrdinalEncoder);
  explicit OrdinalEncoder(uint64_t capacity, float load_factor)
      : capacity_(capacity), table_capacity_(capacity / load_factor) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    OF_CUDA_CHECK(hipMalloc(&table_size_, sizeof(Index)));
    OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&table_size_host_), sizeof(Index)));
    OF_CUDA_CHECK(hipMalloc(&table_keys_, table_capacity_ * sizeof(Key)));
    OF_CUDA_CHECK(hipMalloc(&table_indices_, table_capacity_ * sizeof(Index)));
    Clear();
  }
  ~OrdinalEncoder() {
    CudaCurrentDeviceGuard guard(device_index_);
    OF_CUDA_CHECK(hipFree(table_size_));
    OF_CUDA_CHECK(hipHostFree(table_size_host_));
    OF_CUDA_CHECK(hipFree(table_keys_));
    OF_CUDA_CHECK(hipFree(table_indices_));
  }
  template<bool insert>
  void Encode(ep::Stream* stream, uint32_t num_keys, const Key* keys, Index* context) {
    if (insert) {
      RUN_CUDA_KERNEL((OrdinalEncodeKernel<Key, Index>), stream, num_keys, table_capacity_,
                      table_keys_, table_indices_, table_size_, num_keys, keys, context);
      OF_CUDA_CHECK(hipMemcpyAsync(table_size_host_, table_size_, sizeof(Index), hipMemcpyDefault,
                                   stream->As<ep::CudaStream>()->cuda_stream()));
      CHECK_JUST(stream->Sync());
      CHECK_LT(*table_size_host_, capacity_)
          << "The number of keys is larger than the cache capacity, please enlarge "
             "cache_memory_budget.";
    } else {
      RUN_CUDA_KERNEL((OrdinalEncodeLookupKernel<Key, Index>), stream, num_keys, table_capacity_,
                      table_keys_, table_indices_, num_keys, keys, context);
    }
  }
  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
            uint32_t* n_dumped, Key* keys, Index* context) {
    OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index>), stream, end_key_index - start_key_index,
                    table_keys_, table_indices_, start_key_index, end_key_index, n_dumped, keys,
                    context);
  }
  void Clear() {
    OF_CUDA_CHECK(hipMemset(table_size_, 0, sizeof(Index)));
    OF_CUDA_CHECK(hipMemset(table_keys_, 0, table_capacity_ * sizeof(Key)));
    OF_CUDA_CHECK(hipMemset(table_indices_, 0, table_capacity_ * sizeof(Index)));
  }
  uint64_t TableCapacity() const { return table_capacity_; }
  Key* table_keys() const { return table_keys_; }
  Index* table_indices() const { return table_indices_; }

 private:
  int device_index_{};
  Key* table_keys_;
  Index* table_indices_;
  uint64_t capacity_;
  uint64_t table_capacity_;
  Index* table_size_{};
  Index* table_size_host_{};
};
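
// OrdinalEncoder is a GPU key table (table_keys_ / table_indices_, sized capacity / load_factor)
// that assigns each distinct key a dense ordinal, so values can live in one flat array indexed by
// ordinal - 1. Encode<true> inserts keys it has not seen and fails the CHECK_LT once more than
// capacity_ distinct keys exist; Encode<false> is a read-only lookup. A rough usage sketch,
// assuming a stream and device buffers d_keys / d_context already exist:
//
//   OrdinalEncoder<uint64_t, uint32_t> encoder(1 << 20, 0.75f);
//   encoder.Encode<true>(stream, num_keys, d_keys, d_context);   // assign ordinals, may insert
//   encoder.Encode<false>(stream, num_keys, d_keys, d_context);  // lookup only
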
template<typename Key, typename Elem, typename Index, size_t pack_size>
class CacheImpl : public Cache {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CacheImpl);
  explicit CacheImpl(const CacheOptions& options)
      : encoder_(options.capacity, options.load_factor),
        device_index_(-1),
        options_(options),
        max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    const uint64_t values_size = options.capacity * options.value_size;
    if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
      OF_CUDA_CHECK(hipMalloc(&values_, values_size));
    } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
      if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
        OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&values_), values_size));
      } else {
        OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&values_),
                                              values_size));
      }
    } else {
      UNIMPLEMENTED();
    }
    num_elem_per_value_ = options_.value_size / sizeof(Elem);
  }
  ~CacheImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    if (options_.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
      OF_CUDA_CHECK(hipFree(values_));
    } else if (options_.value_memory_kind == CacheOptions::MemoryKind::kHost) {
      OF_CUDA_CHECK(hipHostFree(values_));
    } else {
      UNIMPLEMENTED();
    }
    if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
  }
  uint64_t Capacity() const override { return options_.capacity; }
  uint64_t DumpCapacity() const override { return encoder_.TableCapacity(); }
  uint32_t KeySize() const override { return options_.key_size; }
  uint32_t ValueSize() const override { return options_.value_size; }
  DataType ValueType() const override { return options_.value_type; }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
    OF_CUDA_CHECK(hipMalloc(&encoding_buffer_, query_length * sizeof(uint64_t)));
    max_query_length_ = query_length;
  }
  CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kFull; }
  void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
            void* missing_keys, uint32_t* missing_indices) override;
  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
           void* missing_keys, uint32_t* missing_indices) override;
  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,
           uint8_t* mask) override;
  void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
           uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override;
  void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
                          const void* update, const float* lr, float scale, uint32_t* n_evicted,
                          void* evicted_keys, void* evicted_values) override;
  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
            uint32_t* n_dumped, void* keys, void* values) override;
  void Clear() override;

 private:
  OrdinalEncoder<Key, Index> encoder_;
  int device_index_;
  uint32_t num_elem_per_value_{};
  Elem* values_;
  Index* encoding_buffer_{};
  CacheOptions options_;
  uint32_t max_query_length_;
};
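
// CacheImpl implements the "full" cache policy: every key that is ever Put() receives a fixed row
// from OrdinalEncoder, values_ holds capacity * value_size bytes either in device memory or in
// (optionally NUMA-aware) pinned host memory, and nothing is evicted. ReserveQueryLength()
// preallocates encoding_buffer_, so callers must reserve at least the largest batch they will
// query. A rough construction sketch, using only the CacheOptions fields referenced in this file:
//
//   CacheOptions opts{};
//   opts.capacity = 1 << 20;                                  // number of cached rows
//   opts.key_size = sizeof(uint64_t);
//   opts.value_size = 128 * sizeof(float);                    // one embedding row
//   opts.value_type = DataType::kFloat;
//   opts.load_factor = 0.75f;
//   opts.value_memory_kind = CacheOptions::MemoryKind::kDevice;
//   std::unique_ptr<Cache> cache = NewFullCache(opts);
//   cache->ReserveQueryLength(65536);
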
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Test(ep::Stream* stream, uint32_t n_keys,
                                                  const void* keys, uint32_t* n_missing,
                                                  void* missing_keys, uint32_t* missing_indices) {
  OF_CUDA_CHECK(
      hipMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  encoder_.template Encode<false>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  RUN_CUDA_KERNEL((LookupKernel<Key, Elem, Index, false>), stream, values_elem_cnt,
                  num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
                  encoding_buffer_, nullptr, n_missing, static_cast<Key*>(missing_keys),
                  missing_indices);
}

template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, void* values,
                                                 uint32_t* n_missing, void* missing_keys,
                                                 uint32_t* missing_indices) {
  OF_CUDA_CHECK(
      hipMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  constexpr uint32_t block_size = 128;
  uint32_t grid_size = (n_keys + block_size - 1) / block_size;
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  EncodeLookupKernel<Key, Elem, Index, block_size>
      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
          num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
          encoding_buffer_, static_cast<Elem*>(values), n_missing, static_cast<Key*>(missing_keys),
          missing_indices, encoder_.TableCapacity(), encoder_.table_keys(),
          encoder_.table_indices());
}

template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, void* values, uint8_t* mask) {
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  constexpr uint32_t block_size = 128;
  uint32_t grid_size = (n_keys + block_size - 1) / block_size;
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  EncodeLookupMaskKernel<Key, Elem, Index, block_size, pack_size>
      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
          num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
          encoding_buffer_, static_cast<Elem*>(values), mask, encoder_.TableCapacity(),
          encoder_.table_keys(), encoder_.table_indices());
}

template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Put(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, const void* values,
                                                 uint32_t* n_evicted, void* evicted_keys,
                                                 void* evicted_values) {
  OF_CUDA_CHECK(
      hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  RUN_CUDA_KERNEL((UpdateKernel<Elem, Index, pack_size>), stream, values_elem_cnt / pack_size,
                  num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_,
                  static_cast<const Elem*>(values));
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::FusedHalfUpdatePut(
    ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update,
    const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) {
  if (!std::is_same<Elem, float>::value) { UNIMPLEMENTED(); }
  OF_CUDA_CHECK(
      hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  RUN_CUDA_KERNEL((FusedHalfUpdateKernel<Elem, Index, pack_size>), stream,
                  values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt,
                  encoding_buffer_, static_cast<const Elem*>(values),
                  static_cast<const half*>(update), lr, scale);
}
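
// FusedHalfUpdatePut fuses an SGD-style step into the cache write: for each cached element the
// kernel stores value - (*lr) * scale * float(update), where update is fp16 and value is fp32,
// instead of requiring the caller to apply the optimizer step separately and then call Put().
// Only Elem == float is supported (other instantiations hit UNIMPLEMENTED and the s_trap stub),
// and because this is a full cache n_evicted is simply reset to zero.
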
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Dump(ep::Stream* stream, uint64_t start_key_index,
                                                  uint64_t end_key_index, uint32_t* n_dumped,
                                                  void* keys, void* values) {
  encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
                encoding_buffer_);
  RUN_CUDA_KERNEL((DumpValueKernel<Key, Elem, Index>), stream,
                  num_elem_per_value_ * (end_key_index - start_key_index), num_elem_per_value_,
                  n_dumped, encoding_buffer_, values_, static_cast<Elem*>(values));
}

template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Clear() {
  encoder_.Clear();
}
template<typename Key, typename Index>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
  if (options.value_type == DataType::kFloat) {
    const size_t value_elem_cnt = options.value_size / sizeof(float);
    const size_t half_warp = 16;
    if (value_elem_cnt % 4 == 0 && value_elem_cnt / 4 > half_warp) {
      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 4>(options));
    } else if (value_elem_cnt % 2 == 0 && value_elem_cnt / 2 > half_warp) {
      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 2>(options));
    } else {
      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 1>(options));
    }
  } else if (options.value_size % sizeof(ulonglong2) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, ulonglong2, Index, 1>(options));
  } else if (options.value_size % sizeof(uint64_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint64_t, Index, 1>(options));
  } else if (options.value_size % sizeof(uint32_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint32_t, Index, 1>(options));
  } else if (options.value_size % sizeof(uint16_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint16_t, Index, 1>(options));
  } else {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint8_t, Index, 1>(options));
  }
}
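
// DispatchValueType picks the widest vectorized access that still keeps more than half a warp
// (16 lanes here) busy per value row. For float values: value_size = 512 bytes gives 128 floats,
// 128 / 4 = 32 > 16, so pack_size = 4; value_size = 64 bytes gives 16 floats, and since
// 16 / 4 = 4 and 16 / 2 = 8 are both <= 16, it falls back to pack_size = 1. Non-float values are
// treated as opaque bytes using the widest unsigned type that divides value_size
// (ulonglong2, uint64_t, uint32_t, uint16_t, then uint8_t).
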
template<typename Index>
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
  if (options.key_size == sizeof(Key32)) {
    return DispatchValueType<Key32, Index>(options);
  } else if (options.key_size == sizeof(Key64)) {
    return DispatchValueType<Key64, Index>(options);
  } else {
    UNIMPLEMENTED();
    return nullptr;
  }
}

std::unique_ptr<Cache> DispatchIndexType(const CacheOptions& options) {
  const int64_t table_capacity = static_cast<double>(options.capacity) / options.load_factor;
  if (table_capacity >= (1ULL << 31ULL)) {
    return DispatchKeyType<uint64_t>(options);
  } else {
    return DispatchKeyType<uint32_t>(options);
  }
}
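
// NewFullCache below therefore resolves a concrete CacheImpl in three steps: DispatchIndexType
// chooses uint32_t ordinals unless the encoder table (capacity / load_factor slots) would need
// 2^31 entries or more, DispatchKeyType selects 32- or 64-bit keys from options.key_size, and
// DispatchValueType fixes the element type and pack size as described above.
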
}  // namespace

std::unique_ptr<Cache> NewFullCache(const CacheOptions& options) {
  return DispatchIndexType(options);
}

}  // namespace embedding
}  // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#define ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_

#include <stdint.h>
#include "oneflow/core/common/data_type.h"

namespace oneflow {
namespace embedding {
namespace {

// From https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h
static const uint64_t PRIME64_1 =
    0x9E3779B185EBCA87ULL;  // 0b1001111000110111011110011011000110000101111010111100101010000111
static const uint64_t PRIME64_2 =
    0xC2B2AE3D27D4EB4FULL;  // 0b1100001010110010101011100011110100100111110101001110101101001111
static const uint64_t PRIME64_3 =
    0x165667B19E3779F9ULL;  // 0b0001011001010110011001111011000110011110001101110111100111111001
static const uint64_t PRIME64_4 =
    0x85EBCA77C2B2AE63ULL;  // 0b1000010111101011110010100111011111000010101100101010111001100011
static const uint64_t PRIME64_5 =
    0x27D4EB2F165667C5ULL;  // 0b0010011111010100111010110010111100010110010101100110011111000101

#define XXH_rotl64(x, r) (((x) << (r)) | ((x) >> (64 - (r))))

OF_DEVICE_FUNC uint64_t XXH64_round(uint64_t acc, uint64_t input) {
  acc += input * PRIME64_2;
  acc = XXH_rotl64(acc, 31);
  acc *= PRIME64_1;
  return acc;
}

OF_DEVICE_FUNC uint64_t xxh64_uint64(uint64_t v, uint64_t seed) {
  uint64_t acc = seed + PRIME64_5;
  acc += sizeof(uint64_t);
  acc = acc ^ XXH64_round(0, v);
  acc = XXH_rotl64(acc, 27) * PRIME64_1;
  acc = acc + PRIME64_4;
  acc ^= (acc >> 33);
  acc = acc * PRIME64_2;
  acc = acc ^ (acc >> 29);
  acc = acc * PRIME64_3;
  acc = acc ^ (acc >> 32);
  return acc;
}
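
// xxh64_uint64 is the XXH64 algorithm specialized for a single 8-byte input: start from
// seed + PRIME64_5, add the input length (8), mix the key with one XXH64_round, then apply the
// standard finalization shifts. Reusing the same routine with the distinct seeds below yields
// effectively independent hash functions, e.g. (sketch; num_shards and table_capacity stand for
// whatever sizes the caller uses):
//
//   const size_t shard = ShardingHash()(key) % num_shards;       // placement across shards
//   const size_t slot = FullCacheHash()(key) % table_capacity;   // probe start in the cache table
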
static const size_t kShardingHashSeed = 1;
static const size_t kLocalUniqueHashSeed = 2;
static const size_t kGlobalUniqueHashSeed = 3;
static const size_t kFullCacheHashSeed = 4;
static const size_t kLruCacheHashSeed = 5;

}  // namespace

struct ShardingHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kShardingHashSeed); }
  OF_DEVICE_FUNC size_t operator()(uint32_t v) { return xxh64_uint64(v, kShardingHashSeed); }
  OF_DEVICE_FUNC size_t operator()(int32_t v) {
    return xxh64_uint64(static_cast<uint32_t>(v), kShardingHashSeed);
  }
  OF_DEVICE_FUNC size_t operator()(int64_t v) {
    return xxh64_uint64(static_cast<uint64_t>(v), kShardingHashSeed);
  }
};

struct LocalUniqueHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLocalUniqueHashSeed); }
};

struct GlobalUniqueHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kGlobalUniqueHashSeed); }
};

struct FullCacheHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kFullCacheHashSeed); }
};

struct LruCacheHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLruCacheHashSeed); }
};

}  // namespace embedding
}  // namespace oneflow

#endif  // ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Inspired by https://github.com/NVIDIA-Merlin/HugeCTR/blob/master/gpu_cache/src/nv_gpu_cache.cu

#include "oneflow/core/embedding/lru_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include <new>
#include <hip/hip_runtime.h>

namespace oneflow {
namespace embedding {
namespace {

constexpr int kWarpSize = 64;
constexpr int kNumWarpPerBlock = 2;
constexpr int kBlockSize = kNumWarpPerBlock * kWarpSize;
constexpr unsigned long long int kFullMask = 0xFFFFFFFFFFFFFFFFU;

ep::CudaLaunchConfig GetLaunchConfig(uint32_t n_keys) {
  return ep::CudaLaunchConfig((n_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock,
                              kWarpSize * kNumWarpPerBlock, 0);
}
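
// This LRU cache is a HIP port of the HugeCTR-style set-associative GPU cache: each cache set has
// kWarpSize = 64 ways and is serviced by one full wavefront (AMD wavefronts are 64 lanes wide,
// hence the 64-bit ballot masks and the 64-bit kFullMask). GetLaunchConfig provisions roughly one
// warp per key, ceil(n_keys / kNumWarpPerBlock) blocks of kBlockSize threads, and the kernels
// below then stride over the keys in warp-sized batches.
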
struct ThreadContext {
  __device__ ThreadContext() {
    const uint32_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    global_warp_id = global_thread_id / kWarpSize;
    warp_id_in_block = global_warp_id % kNumWarpPerBlock;  // NOLINT
    num_warps = gridDim.x * kNumWarpPerBlock;              // NOLINT
    lane_id = global_thread_id % kWarpSize;
  }
  uint32_t global_warp_id;
  uint32_t warp_id_in_block;
  uint32_t num_warps;
  uint32_t lane_id;
};

class WarpMutexAtomicImpl {
 public:
  OF_DISALLOW_COPY_AND_MOVE(WarpMutexAtomicImpl);
  __device__ WarpMutexAtomicImpl() : flag_(0) {}
  __device__ ~WarpMutexAtomicImpl() = default;
  __device__ void Lock(const ThreadContext& thread_ctx) {
    if (thread_ctx.lane_id == 0) {
      while (atomicCAS(&flag_, 0, 1) != 0)
        ;
    }
    __threadfence();
    __syncthreads();
  }
  __device__ void Unlock(const ThreadContext& thread_ctx) {
    __syncthreads();
    __threadfence();
    if (thread_ctx.lane_id == 0) { atomicExch(&flag_, 0); }
  }

 private:
  int32_t flag_;
};
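
// WarpMutexAtomicImpl is a per-set spinlock: lane 0 of the wavefront spins on atomicCAS until it
// owns flag_, and the fence plus barrier afterwards make the acquisition visible before the other
// lanes touch the protected set (the barrier here is __syncthreads, which synchronizes both warps
// of the block). Unlock reverses the order: barrier, fence, then atomicExch back to 0. One mutex
// object per cache set is placement-new'ed by InitCacheSetMutex below.
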
template<typename Key, typename Elem>
struct LruCacheContext {
  Key* keys;
  Elem* lines;
  uint8_t* ages;
  void* mutex;
  uint64_t n_set;
  uint32_t line_size;
  CacheOptions::MemoryKind value_memory_kind;
};

__global__ void InitCacheSetMutex(uint32_t n_set, void* mutex) {
  using WarpMutex = WarpMutexAtomicImpl;
  const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n_set) { new (reinterpret_cast<WarpMutex*>(mutex) + idx) WarpMutex; }
}

template<typename Key, typename Elem>
void ClearLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
  OF_CUDA_CHECK(hipMemset(ctx->keys, 0, ctx->n_set * kWarpSize * sizeof(Key)));
  OF_CUDA_CHECK(hipMemset(ctx->ages, 0, ctx->n_set * kWarpSize * sizeof(uint8_t)));
  InitCacheSetMutex<<<(ctx->n_set - 1 + 256) / 256, 256>>>(ctx->n_set, ctx->mutex);
}
template<typename Key, typename Elem>
void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>* ctx) {
  const size_t keys_size_per_set = kWarpSize * sizeof(Key);
  const uint32_t line_size = options.value_size / sizeof(Elem);
  const size_t lines_size_per_set = kWarpSize * line_size * sizeof(Elem);
  const size_t ages_size_per_set = kWarpSize * sizeof(uint8_t);
  int device = 0;
  OF_CUDA_CHECK(hipGetDevice(&device));
  int major = 0;
  OF_CUDA_CHECK(hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, device));
  size_t mutex_size_per_set = 0;
  mutex_size_per_set = sizeof(WarpMutexAtomicImpl);
  const size_t n_set = (options.capacity - 1 + kWarpSize) / kWarpSize;
  CHECK_GT(n_set, 0);
  ctx->n_set = n_set;
  ctx->line_size = line_size;
  const size_t keys_size = n_set * keys_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->keys), keys_size));
  const size_t lines_size = n_set * lines_size_per_set;
  if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
    OF_CUDA_CHECK(hipMalloc(&(ctx->lines), lines_size));
  } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
    if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
      OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&(ctx->lines)), lines_size));
    } else {
      OF_CUDA_CHECK(
          NumaAwareCudaMallocHost(device, reinterpret_cast<void**>(&ctx->lines), lines_size));
    }
  } else {
    UNIMPLEMENTED();
  }
  ctx->value_memory_kind = options.value_memory_kind;
  const size_t ages_size = n_set * ages_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->ages), ages_size));
  const size_t mutex_size = n_set * mutex_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->mutex), mutex_size));
  ClearLruCacheContext(ctx);
}

template<typename Key, typename Elem>
void DestroyLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
  OF_CUDA_CHECK(hipFree(ctx->keys));
  if (ctx->value_memory_kind == CacheOptions::MemoryKind::kDevice) {
    OF_CUDA_CHECK(hipFree(ctx->lines));
  } else if (ctx->value_memory_kind == CacheOptions::MemoryKind::kHost) {
    OF_CUDA_CHECK(hipHostFree(ctx->lines));
  } else {
    UNIMPLEMENTED();
  }
  OF_CUDA_CHECK(hipFree(ctx->ages));
  OF_CUDA_CHECK(hipFree(ctx->mutex));
}
template<typename Key, typename Elem>
struct SetContext {
  using WarpMutex = WarpMutexAtomicImpl;
  __device__ SetContext(const LruCacheContext<Key, Elem>& ctx, uint32_t set_id)
      : keys(ctx.keys + set_id * kWarpSize),
        mutex(reinterpret_cast<WarpMutex*>(ctx.mutex) + set_id),
        ages(ctx.ages + set_id * kWarpSize),
        lines(ctx.lines + set_id * kWarpSize * ctx.line_size) {}
  __device__ int Lookup(const ThreadContext& thread_ctx, Key key) {
    const Key lane_key = keys[thread_ctx.lane_id];
    const int lane_age = ages[thread_ctx.lane_id];
    const bool lane_hit = (lane_key == key && lane_age != 0);
    const unsigned long long int hit_mask = __ballot(lane_hit);
    if (hit_mask != 0) {
      return __ffsll(hit_mask) - 1;
    } else {
      return -1;
    }
  }
  __device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
                       int way, Elem* line) {
    const Elem* from_line = lines + way * cache_ctx.line_size;
    for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
      line[i] = from_line[i];
    }
  }
  __device__ int InsertWithoutEvicting(const LruCacheContext<Key, Elem>& cache_ctx,
                                       const ThreadContext& thread_ctx, Key key) {
    int insert_way = -1;
    const Key lane_key = keys[thread_ctx.lane_id];
    int lane_age = ages[thread_ctx.lane_id];
    const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
    if (hit_mask != 0) {
      insert_way = __ffsll(hit_mask) - 1;
      const int insert_way_age = __shfl(lane_age, insert_way);
      if (lane_age > insert_way_age) {
        lane_age -= 1;
      } else if (thread_ctx.lane_id == insert_way) {
        lane_age = kWarpSize;
      }
      __syncthreads();
    }
    if (insert_way == -1) {
      const unsigned long long int valid_mask = __ballot(lane_age != 0);
      if (valid_mask != kFullMask) {
        insert_way = __popcll(valid_mask);
        if (lane_age > 0) {
          lane_age -= 1;
        } else if (thread_ctx.lane_id == insert_way) {
          lane_age = kWarpSize;
          keys[insert_way] = key;
        }
        __syncthreads();
      }
    }
    if (insert_way != -1) { ages[thread_ctx.lane_id] = lane_age; }
    return insert_way;
  }
  __device__ void Evict(const LruCacheContext<Key, Elem>& cache_ctx,
                        const ThreadContext& thread_ctx, Key key, int* way, Key* evicted_key) {
    const Key lane_key = keys[thread_ctx.lane_id];
    int lane_age = ages[thread_ctx.lane_id];
    const int insert_way = __ffsll(__ballot(lane_age == 1)) - 1;
    *evicted_key = __shfl(lane_key, insert_way);
    if (thread_ctx.lane_id == insert_way) {
      keys[insert_way] = key;
      lane_age = kWarpSize;
    } else if (lane_age > 1) {
      lane_age -= 1;
    }
    __syncthreads();
    ages[thread_ctx.lane_id] = lane_age;
    *way = insert_way;
  }
  __device__ void Write(const LruCacheContext<Key, Elem>& cache_ctx,
                        const ThreadContext& thread_ctx, int way, const Elem* line) {
    Elem* to_line = lines + way * cache_ctx.line_size;
    for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
      to_line[i] = line[i];
    }
  }
  __device__ void Lock(const ThreadContext& thread_ctx) { mutex->Lock(thread_ctx); }
  __device__ void Unlock(const ThreadContext& thread_ctx) { mutex->Unlock(thread_ctx); }

  Key* keys;
  Elem* lines;
  uint8_t* ages;
  WarpMutex* mutex;
};
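
// SetContext treats one 64-way cache set as a wavefront-wide structure: lane i owns way i, an age
// of 0 marks an empty way, kWarpSize marks the most recently used way, and 1 marks the LRU victim
// that Evict() replaces. Because __ballot() produces a 64-bit lane mask on this platform, the way
// indices are extracted with the 64-bit intrinsics __ffsll / __popcll; their 32-bit counterparts
// would ignore ways 32-63.
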
template<typename Key, typename Elem, bool test_only>
__global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys, const Key* keys,
                          Elem* values, uint32_t* n_missing_keys, Key* missing_keys,
                          uint32_t* missing_indices) {
  ThreadContext thread_ctx{};
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    __syncthreads();
    uint32_t n_warp_missing = 0;
    Key warp_missing_key = 0;
    uint32_t warp_missing_index = 0;
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      const int way = set_ctx.Lookup(thread_ctx, key);
      if (way < 0) {
        if (thread_ctx.lane_id == n_warp_missing) {
          warp_missing_key = key;
          warp_missing_index = key_idx;
        }
        __syncthreads();
        n_warp_missing += 1;
      } else if (!test_only) {
        set_ctx.Read(cache_ctx, thread_ctx, way, values + key_idx * cache_ctx.line_size);
      }
    }
    if (n_warp_missing > 0) {
      uint32_t base_missing_idx = 0;
      if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing_keys, n_warp_missing); }
      __syncthreads();
      base_missing_idx = __shfl(base_missing_idx, 0);
      if (thread_ctx.lane_id < n_warp_missing) {
        missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
        missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
      }
      __syncthreads();
    }
    __syncthreads();
  }
}
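
// GetKernel processes keys in warp-sized batches: the warp first stages up to 64 keys and their
// set ids in shared memory, then walks the batch one key at a time so all 64 lanes cooperate on
// each set lookup (and on copying the hit line when test_only is false). Misses are compacted per
// warp: lane n_warp_missing remembers the n-th missing key, and one atomicAdd on n_missing_keys
// reserves the output slots that the first n_warp_missing lanes then fill.
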
template<typename Key, typename Elem>
__global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys,
                                         const Key* keys, const Elem* values, uint32_t* n_missing,
                                         Key* missing_keys, uint32_t* missing_indices) {
  ThreadContext thread_ctx{};
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    __syncthreads();
    uint32_t n_warp_missing = 0;
    Key warp_missing_key = 0;
    uint32_t warp_missing_index = 0;
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      set_ctx.Lock(thread_ctx);
      Key evicted_key = 0;
      const int insert_way = set_ctx.InsertWithoutEvicting(cache_ctx, thread_ctx, key);
      if (insert_way >= 0) {
        set_ctx.Write(cache_ctx, thread_ctx, insert_way, values + cache_ctx.line_size * key_idx);
      } else {
        if (thread_ctx.lane_id == n_warp_missing) {
          warp_missing_key = key;
          warp_missing_index = key_idx;
        }
        __syncthreads();
        n_warp_missing += 1;
      }
      set_ctx.Unlock(thread_ctx);
    }
    if (n_warp_missing > 0) {
      uint32_t base_missing_idx = 0;
      if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing, n_warp_missing); }
      __syncthreads();
      base_missing_idx = __shfl(base_missing_idx, 0);
      if (thread_ctx.lane_id < n_warp_missing) {
        missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
        missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
      }
      __syncthreads();
    }
  }
}
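// Second pass of Put: for every key that PutWithoutEvictingKernel could not place,
// evict a way of the target set, report the evicted key/value back to the caller,
// and write the new value into the freed slot.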
template<typename Key, typename Elem>
__global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* keys,
                            const uint32_t* indices, const Elem* values, const uint32_t* n_evict,
                            Key* evicted_keys, Elem* evicted_values) {
  ThreadContext thread_ctx{};
  uint32_t num_evict = *n_evict;
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_evict;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    const uint32_t n_batch_keys = min(kWarpSize, num_evict - batch_offset);
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    __syncthreads();
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const uint32_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      set_ctx.Lock(thread_ctx);
      int evicted_way = -1;
      Key evicted_key = 0;
      set_ctx.Evict(cache_ctx, thread_ctx, key, &evicted_way, &evicted_key);
      if (thread_ctx.lane_id == 0) { evicted_keys[key_idx] = evicted_key; }
      __syncthreads();
      set_ctx.Read(cache_ctx, thread_ctx, evicted_way,
                   evicted_values + cache_ctx.line_size * key_idx);
      set_ctx.Write(cache_ctx, thread_ctx, evicted_way,
                    values + cache_ctx.line_size * indices[key_idx]);
      set_ctx.Unlock(thread_ctx);
    }
  }
}
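// Dumps every valid cache line (age != 0) whose slot index lies in
// [start_key_index, end_key_index), packing keys and values contiguously into the
// output buffers via an atomically incremented offset.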
template<typename Key, typename Elem>
__global__ void DumpKernel(LruCacheContext<Key, Elem> cache_ctx, size_t start_key_index,
                           size_t end_key_index, uint32_t* n_dumped, Key* keys, Elem* values) {
  ThreadContext thread_ctx{};
  __shared__ Key warp_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ uint8_t warp_ages[kNumWarpPerBlock][kWarpSize];
  for (uint32_t warp_start_key_index = start_key_index + thread_ctx.global_warp_id * kWarpSize;
       warp_start_key_index < end_key_index;
       warp_start_key_index += thread_ctx.num_warps * kWarpSize) {
    Key lane_key = 0;
    uint8_t lane_age = 0;
    if (warp_start_key_index + thread_ctx.lane_id < end_key_index) {
      lane_key = cache_ctx.keys[warp_start_key_index + thread_ctx.lane_id];
      lane_age = cache_ctx.ages[warp_start_key_index + thread_ctx.lane_id];
    }
    __syncthreads();
    warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
    warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
    const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
    if (key_count == 0) { continue; }
    uint32_t offset = 0;
    if (thread_ctx.lane_id == 0) { offset = atomicAdd(n_dumped, key_count); }
    offset = __shfl(offset, 0);
    __syncthreads();
    for (uint32_t i = 0; i < kWarpSize; ++i) {
      const Key key = warp_keys[thread_ctx.warp_id_in_block][i];
      const Key age = warp_ages[thread_ctx.warp_id_in_block][i];
      if (age == 0) { continue; }
      if (thread_ctx.lane_id == 0) { keys[offset] = key; }
      __syncthreads();
      for (uint32_t j = thread_ctx.lane_id; j < cache_ctx.line_size; j += kWarpSize) {
        values[offset * cache_ctx.line_size + j] =
            cache_ctx.lines[(warp_start_key_index + i) * cache_ctx.line_size + j];
      }
      __syncthreads();
      offset += 1;
    }
  }
}
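// Host-side Cache implementation. It owns the device LruCacheContext and launches
// the kernels above; query_keys_buffer_ / query_indices_buffer_ are device scratch
// buffers (sized by ReserveQueryLength) used to pass the "missing" keys from
// PutWithoutEvictingKernel to EvictKernel.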
template<typename Key, typename Elem>
class LruCache : public Cache {
 public:
  OF_DISALLOW_COPY_AND_MOVE(LruCache);
  explicit LruCache(const CacheOptions& options)
      : device_index_{},
        max_query_length_(0),
        query_indices_buffer_(nullptr),
        query_keys_buffer_(nullptr),
        value_type_(options.value_type) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    InitLruCacheContext(options, &ctx_);
  }
  ~LruCache() override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipFree(query_indices_buffer_));
      OF_CUDA_CHECK(hipFree(query_keys_buffer_));
    }
    DestroyLruCacheContext(&ctx_);
  }
  uint32_t KeySize() const override { return sizeof(Key); }
  uint32_t ValueSize() const override { return sizeof(Elem) * ctx_.line_size; }
  DataType ValueType() const override { return value_type_; }
  uint64_t Capacity() const override { return ctx_.n_set * kWarpSize; }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length < max_query_length_) { return; }
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipFree(query_indices_buffer_));
      OF_CUDA_CHECK(hipFree(query_keys_buffer_));
    }
    OF_CUDA_CHECK(hipMalloc(&query_indices_buffer_, query_length * sizeof(uint32_t)));
    OF_CUDA_CHECK(hipMalloc(&query_keys_buffer_, query_length * sizeof(Key)));
    max_query_length_ = query_length;
  }
  CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kLRU; }
  void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
            void* missing_keys, uint32_t* missing_indices) override {
    CHECK_LE(n_keys, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
    if (n_keys == 0) { return; }
    cuda_stream->LaunchKernel(GetKernel<Key, Elem, true>, GetLaunchConfig(n_keys), ctx_, n_keys,
                              static_cast<const Key*>(keys), nullptr, n_missing,
                              static_cast<Key*>(missing_keys), missing_indices);
  }
  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
           void* missing_keys, uint32_t* missing_indices) override {
    CHECK_LE(n_keys, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
    if (n_keys == 0) { return; }
    cuda_stream->LaunchKernel(GetKernel<Key, Elem, false>, GetLaunchConfig(n_keys), ctx_, n_keys,
                              static_cast<const Key*>(keys), static_cast<Elem*>(values), n_missing,
                              static_cast<Key*>(missing_keys), missing_indices);
  }
  void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
           uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override {
    CHECK_LE(n_keys, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
    if (n_keys == 0) { return; }
    cuda_stream->LaunchKernel(PutWithoutEvictingKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
                              n_keys, static_cast<const Key*>(keys),
                              static_cast<const Elem*>(values), n_evicted, query_keys_buffer_,
                              query_indices_buffer_);
    cuda_stream->LaunchKernel(EvictKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
                              query_keys_buffer_, query_indices_buffer_,
                              static_cast<const Elem*>(values), n_evicted,
                              static_cast<Key*>(evicted_keys), static_cast<Elem*>(evicted_values));
  }
  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
            uint32_t* n_dumped, void* keys, void* values) override {
    auto cuda_stream = stream->As<ep::CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
    const uint64_t max_dump_keys = end_key_index - start_key_index;
    cuda_stream->LaunchKernel(
        DumpKernel<Key, Elem>,
        ep::CudaLaunchConfig((max_dump_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock, kBlockSize,
                             0),
        ctx_, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
        static_cast<Elem*>(values));
  }
  void Clear() override { ClearLruCacheContext<Key, Elem>(&ctx_); }

 private:
  int device_index_;
  uint32_t max_query_length_;
  LruCacheContext<Key, Elem> ctx_;
  uint32_t* query_indices_buffer_;
  Key* query_keys_buffer_;
  DataType value_type_;
};
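// Picks the widest element type whose size divides the configured value size, so
// each cache line is moved with as few memory transactions as possible.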
template<typename Key>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
  if (options.value_size % sizeof(ulonglong2) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, ulonglong2>(options));
  } else if (options.value_size % sizeof(uint64_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint64_t>(options));
  } else if (options.value_size % sizeof(uint32_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint32_t>(options));
  } else if (options.value_size % sizeof(uint16_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint16_t>(options));
  } else {
    return std::unique_ptr<Cache>(new LruCache<Key, uint8_t>(options));
  }
}

std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
  if (options.key_size == sizeof(uint32_t)) {
    return DispatchValueType<uint32_t>(options);
  } else if (options.key_size == sizeof(uint64_t)) {
    return DispatchValueType<uint64_t>(options);
  } else {
    UNIMPLEMENTED();
    return nullptr;
  }
}

}  // namespace

std::unique_ptr<Cache> NewLruCache(const CacheOptions& options) { return DispatchKeyType(options); }

}  // namespace embedding
}  // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/mock_key_value_store.h" #include "oneflow/core/embedding/mock_key_value_store.h"
#include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cuda_util.h"
namespace oneflow { namespace oneflow {
namespace embedding { namespace embedding {
namespace { namespace {
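// Iterator over the in-memory HashMap store: copies up to n_request entries into
// the pinned host buffers, then issues async copies to the device output buffers.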
template<typename Key>
class IteratorImpl : public KVIterator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
  IteratorImpl(HashMap<Key, std::string>* store, uint32_t key_size, uint32_t value_size,
               uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
               uint32_t* host_num_buffer)
      : store_(store),
        pos_(store->begin()),
        key_size_(key_size),
        value_size_(value_size),
        max_query_length_(max_query_length),
        host_keys_buffer_(host_keys_buffer),
        host_values_buffer_(host_values_buffer),
        host_num_buffer_(host_num_buffer) {}
  ~IteratorImpl() override = default;
  void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
             void* values) override {
    CHECK_LE(n_request, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    CHECK_JUST(cuda_stream->Sync());
    *host_num_buffer_ = 0;
    while (*host_num_buffer_ < n_request && pos_ != store_->end()) {
      reinterpret_cast<Key*>(host_keys_buffer_)[*host_num_buffer_] = pos_->first;
      std::memcpy(reinterpret_cast<char*>(host_values_buffer_) + *host_num_buffer_ * value_size_,
                  pos_->second.data(), value_size_);
      // Advance both the output count and the store iterator; without this the
      // loop would never terminate.
      *host_num_buffer_ += 1;
      ++pos_;
    }
    OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                                 cuda_stream->cuda_stream()));
    const uint32_t num_keys = *host_num_buffer_;
    if (num_keys != 0) {
      OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_, hipMemcpyDefault,
                                   cuda_stream->cuda_stream()));
      OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
                                   hipMemcpyDefault, cuda_stream->cuda_stream()));
    }
  }
  void Reset() override { pos_ = store_->begin(); }

 private:
  HashMap<Key, std::string>* store_;
  typename HashMap<Key, std::string>::iterator pos_;
  uint32_t key_size_;
  uint32_t value_size_;
  uint32_t max_query_length_;
  void* host_keys_buffer_;
  void* host_values_buffer_;
  uint32_t* host_num_buffer_;
};
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
  explicit KeyValueStoreImpl(const MockKeyValueStoreOptions& options)
      : device_index_(-1), max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    key_size_ = options.key_size;
    value_size_ = options.value_size;
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(
        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_query_values_),
                                          value_size_ * max_query_length_));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),
                                          sizeof(uint32_t)));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_missing_indices_),
                                          sizeof(uint32_t) * max_query_length_));
  }
  ~KeyValueStoreImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipHostFree(host_query_keys_));
      OF_CUDA_CHECK(hipHostFree(host_query_values_));
      OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
    }
    OF_CUDA_CHECK(hipHostFree(host_n_missing_));
  }
  uint32_t KeySize() const override { return key_size_; }
  uint32_t ValueSize() const override { return value_size_; }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipHostFree(host_query_keys_));
      OF_CUDA_CHECK(hipHostFree(host_query_values_));
      OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
    }
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(
        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(
        device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_missing_indices_),
                                          sizeof(uint32_t) * query_length));
    max_query_length_ = query_length;
  }
  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
           uint32_t* n_missing, uint32_t* missing_indices) override;
  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
  bool SnapshotExists(const std::string& name) override;
  void LoadSnapshot(const std::string& name) override;
  void LoadSnapshot(const std::string& name,
                    const std::function<void(KVIterator* iter)>& Hook) override;
  void SaveSnapshot(const std::string& name) override;

 private:
  int device_index_;
  uint32_t max_query_length_;
  uint32_t key_size_;
  uint32_t value_size_;
  Key* host_query_keys_{};
  uint8_t* host_query_values_{};
  uint32_t* host_n_missing_{};
  uint32_t* host_missing_indices_{};
  HashMap<Key, std::string> store_;
  HashMap<std::string, HashMap<Key, std::string>> snapshots_;
  std::mutex mutex_;
};
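// Get/Put stage device data through the pinned host buffers: copy keys (and values
// for Put) to the host, synchronize the stream, operate on the HashMap, then copy
// the results back asynchronously on the same stream.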
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 void* values, uint32_t* n_missing, uint32_t* missing_indices) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) {
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    return;
  }
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  *host_n_missing_ = 0;
  for (uint32_t i = 0; i < num_keys; ++i) {
    auto it = store_.find(host_query_keys_[i]);
    if (it != store_.end()) {
      std::memcpy(host_query_values_ + i * value_size_, it->second.data(), value_size_);
    } else {
      host_missing_indices_[*host_n_missing_] = i;
      *host_n_missing_ += 1;
    }
  }
  OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_,
                               hipMemcpyDefault, cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
                               (*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
}

template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 const void* values) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) { return; }
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
                               hipMemcpyDefault, cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  for (uint32_t i = 0; i < num_keys; ++i) {
    store_[host_query_keys_[i]] = std::string(
        reinterpret_cast<const char*>(host_query_values_) + i * value_size_, value_size_);
  }
}

template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
  return snapshots_.find(name) != snapshots_.end();
}

template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  LoadSnapshot(name, nullptr);
}

template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
                                          const std::function<void(KVIterator* iter)>& Hook) {
  CudaCurrentDeviceGuard guard(device_index_);
  store_ = snapshots_[name];
  if (Hook) {
    IteratorImpl<Key> iterator(&store_, KeySize(), ValueSize(), max_query_length_, host_query_keys_,
                               host_query_values_, host_n_missing_);
    Hook(&iterator);
  }
}

template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  snapshots_[name] = store_;
}

}  // namespace

std::unique_ptr<KeyValueStore> NewMockKeyValueStore(const MockKeyValueStoreOptions& options) {
  if (options.key_size == sizeof(uint64_t)) {
    return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
  } else if (options.key_size == sizeof(uint32_t)) {
    return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
  } else {
    UNIMPLEMENTED();
    return nullptr;
  }
}

}  // namespace embedding
}  // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/persistent_table_key_value_store.h" #include "oneflow/core/embedding/persistent_table_key_value_store.h"
#include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/persistent_table.h" #include "oneflow/core/embedding/persistent_table.h"
#include <robin_hood.h> #include <robin_hood.h>
#include <fcntl.h> #include <fcntl.h>
#include <sys/mman.h> #include <sys/mman.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <dirent.h> #include <dirent.h>
namespace oneflow { namespace oneflow {
namespace embedding { namespace embedding {
namespace { namespace {
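// Thin KVIterator adapter over PersistentTable::Iterator: Next() fills the pinned
// host buffers, and the results are then copied to device memory asynchronously.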
class IteratorImpl : public KVIterator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
  IteratorImpl(PersistentTable::Iterator* base_iter, uint32_t key_size, uint32_t value_size,
               uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
               uint32_t* host_num_buffer)
      : base_iter_(base_iter),
        key_size_(key_size),
        value_size_(value_size),
        max_query_length_(max_query_length),
        host_keys_buffer_(host_keys_buffer),
        host_values_buffer_(host_values_buffer),
        host_num_buffer_(host_num_buffer) {}
  ~IteratorImpl() override = default;
  void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
             void* values) override {
    CHECK_LE(n_request, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    CHECK_JUST(cuda_stream->Sync());
    base_iter_->Next(n_request, host_num_buffer_, host_keys_buffer_, host_values_buffer_);
    OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                                 cuda_stream->cuda_stream()));
    const uint32_t num_keys = *host_num_buffer_;
    if (num_keys != 0) {
      OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_, hipMemcpyDefault,
                                   cuda_stream->cuda_stream()));
      OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
                                   hipMemcpyDefault, cuda_stream->cuda_stream()));
    }
  }
  void Reset() override { base_iter_->Reset(); }

 private:
  PersistentTable::Iterator* base_iter_;
  uint32_t key_size_;
  uint32_t value_size_;
  uint32_t max_query_length_;
  void* host_keys_buffer_;
  void* host_values_buffer_;
  uint32_t* host_num_buffer_;
};
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
  explicit KeyValueStoreImpl(const PersistentTableKeyValueStoreOptions& options)
      : device_index_(-1), max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    key_size_ = options.table_options.key_size;
    value_size_ = options.table_options.value_size;
    table_ = NewPersistentTable(options.table_options);
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(
        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_query_values_),
                                          value_size_ * max_query_length_));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),
                                          sizeof(uint32_t)));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_missing_indices_),
                                          sizeof(uint32_t) * max_query_length_));
  }
  ~KeyValueStoreImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipHostFree(host_query_keys_));
      OF_CUDA_CHECK(hipHostFree(host_query_values_));
      OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
    }
    OF_CUDA_CHECK(hipHostFree(host_n_missing_));
  }
  uint32_t KeySize() const override { return key_size_; }
  uint32_t ValueSize() const override { return value_size_; }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipHostFree(host_query_keys_));
      OF_CUDA_CHECK(hipHostFree(host_query_values_));
      OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
    }
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(
        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(
        device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_missing_indices_),
                                          sizeof(uint32_t) * query_length));
    max_query_length_ = query_length;
  }
  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
           uint32_t* n_missing, uint32_t* missing_indices) override;
  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
  bool SnapshotExists(const std::string& name) override;
  void LoadSnapshot(const std::string& name) override;
  void LoadSnapshot(const std::string& name,
                    const std::function<void(KVIterator* iter)>& Hook) override;
  void SaveSnapshot(const std::string& name) override;

 private:
  int device_index_;
  uint32_t max_query_length_;
  uint32_t key_size_;
  uint32_t value_size_;
  Key* host_query_keys_{};
  uint8_t* host_query_values_{};
  uint32_t* host_n_missing_{};
  uint32_t* host_missing_indices_{};
  std::mutex mutex_;
  std::unique_ptr<PersistentTable> table_;
};
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 void* values, uint32_t* n_missing, uint32_t* missing_indices) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) {
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    return;
  }
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  table_->Get(num_keys, host_query_keys_, host_query_values_, host_n_missing_,
              host_missing_indices_);
  OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_,
                               hipMemcpyDefault, cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
                               (*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
}

template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 const void* values) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) { return; }
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
                               hipMemcpyDefault, cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  table_->Put(num_keys, host_query_keys_, host_query_values_);
}

template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
  return table_->SnapshotExists(name);
}

template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  LoadSnapshot(name, nullptr);
}

template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
                                          const std::function<void(KVIterator* iter)>& Hook) {
  CudaCurrentDeviceGuard guard(device_index_);
  if (Hook) {
    table_->LoadSnapshot(name, [&](PersistentTable::Iterator* chunk_iterator) {
      IteratorImpl iterator(chunk_iterator, KeySize(), ValueSize(), max_query_length_,
                            host_query_keys_, host_query_values_, host_n_missing_);
      Hook(&iterator);
    });
  } else {
    table_->LoadSnapshot(name);
  }
}

template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  table_->SaveSnapshot(name);
}

}  // namespace

std::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(
    const PersistentTableKeyValueStoreOptions& options) {
  if (options.table_options.key_size == sizeof(uint64_t)) {
    return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
  } else if (options.table_options.key_size == sizeof(uint32_t)) {
    return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
  } else {
    UNIMPLEMENTED();
    return nullptr;
  }
}

}  // namespace embedding
}  // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_device.h" #include "oneflow/core/ep/rocm/cuda_device.h"
#include "oneflow/core/ep/rocm/cuda_event.h" #include "oneflow/core/ep/rocm/cuda_event.h"
#include "oneflow/core/ep/rocm/cuda_stream.h" #include "oneflow/core/ep/rocm/cuda_stream.h"
#ifdef WITH_ROCM #ifdef WITH_ROCM
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
// #if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h> // #include <cuda_bf16.h>
// #endif // #endif
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace { namespace {
constexpr size_t kDefaultConstBufElementCount = 1024 * 1024; constexpr size_t kDefaultConstBufElementCount = 1024 * 1024;
template<typename T> template<typename T>
void CreateConstBuffer(void** buf, T value, size_t n) { void CreateConstBuffer(void** buf, T value, size_t n) {
OF_CUDA_CHECK(hipMalloc(buf, n * sizeof(T))); OF_CUDA_CHECK(hipMalloc(buf, n * sizeof(T)));
std::vector<T> host(n, value); std::vector<T> host(n, value);
OF_CUDA_CHECK(hipMemcpy(*buf, host.data(), n * sizeof(T), hipMemcpyDefault)); OF_CUDA_CHECK(hipMemcpy(*buf, host.data(), n * sizeof(T), hipMemcpyDefault));
} }
} // namespace } // namespace
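// CudaDevice owns per-device resources: a pool of reusable events and device
// buffers of constant zeros/ones (fp32/fp16), whose element count is controlled by
// ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT.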
CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager)
    : device_index_(device_index),
      event_flags_{},
      properties_{},
      device_manager_(device_manager),
      const_buf_elem_cnt_(0),
      const_zeros_buffer_(nullptr),
      const_ones_buffer_fp32_(nullptr),
      const_ones_buffer_fp16_(nullptr),
      const_ones_buffer_bf16_(nullptr) {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipGetDeviceProperties(&properties_, device_index_));
  event_flags_ = hipEventDisableTiming;
  if (ParseBooleanFromEnv("ONEFLOW_STREAM_CUDA_EVENT_FLAG_BLOCKING_SYNC", false)) {
    event_flags_ |= hipEventBlockingSync;
  }
  const_buf_elem_cnt_ = ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CONST_BUFFER_ELEMENT_COUNT",
                                            kDefaultConstBufElementCount);
  if (const_buf_elem_cnt_ > 0) {
    CreateConstBuffer<float>(&const_zeros_buffer_, static_cast<float>(0), const_buf_elem_cnt_);
    CreateConstBuffer<float>(&const_ones_buffer_fp32_, static_cast<float>(1.0),
                             const_buf_elem_cnt_);
    CreateConstBuffer<half>(&const_ones_buffer_fp16_, static_cast<half>(1.0), const_buf_elem_cnt_);
    // #if CUDA_VERSION >= 11000
    // CreateConstBuffer<nv_bfloat16>(&const_ones_buffer_bf16_, static_cast<nv_bfloat16>(1.0),
    //                                const_buf_elem_cnt_);
    // #endif
  }
}
CudaDevice::~CudaDevice() {
  CudaCurrentDeviceGuard guard(device_index_);
  for (auto* event : events_) { delete event; }
  OF_CUDA_CHECK(hipFree(const_zeros_buffer_));
  OF_CUDA_CHECK(hipFree(const_ones_buffer_fp32_));
  OF_CUDA_CHECK(hipFree(const_ones_buffer_fp16_));
  OF_CUDA_CHECK(hipFree(const_ones_buffer_bf16_));
}
void CudaDevice::SetAsActiveDevice() { OF_CUDA_CHECK(hipSetDevice(device_index_)); }
Stream* CudaDevice::CreateStream() {
  CudaCurrentDeviceGuard guard(device_index_);
  return new CudaStream(this);
}
void CudaDevice::DestroyStream(Stream* stream) {
  CudaCurrentDeviceGuard guard(device_index_);
  delete stream;
}
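// Events are pooled: CreateEvents first hands out events recycled by DestroyEvents (taken from
// the tail of events_ under events_mutex_) and only constructs new CudaEvent objects when the
// pool cannot satisfy the request.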
void CudaDevice::CreateEvents(Event** events, size_t count) {
  size_t copied = 0;
  {
    std::lock_guard<std::mutex> lock(events_mutex_);
    copied = std::min(count, events_.size());
    size_t offset = events_.size() - copied;
    std::copy(events_.begin() + offset, events_.end(), events);
    events_.resize(offset);
  }
  if (copied != count) {
    CudaCurrentDeviceGuard guard(device_index_);
    for (size_t i = copied; i < count; ++i) { events[i] = new CudaEvent(event_flags_); }
  }
}
void CudaDevice::DestroyEvents(Event** events, size_t count) {
  std::lock_guard<std::mutex> lock(events_mutex_);
  events_.insert(events_.end(), events, events + count);
}
Maybe<void> CudaDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) {
  CudaCurrentDeviceGuard guard(device_index_);
  CHECK(!options.HasPinnedDevice());
  hipError_t err = hipMalloc(ptr, size);
  if (err != hipSuccess) {
    return Error::RuntimeError() << hipGetErrorString(err);
  } else {
    return Maybe<void>::Ok();
  }
}
void CudaDevice::Free(const AllocationOptions& attr, void* ptr) {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipFree(ptr));
}
Maybe<void> CudaDevice::AllocPinned(const AllocationOptions& options, void** ptr, size_t size) {
  CudaCurrentDeviceGuard guard(device_index_);
  hipError_t err = NumaAwareCudaMallocHost(device_index_, ptr, size);
  if (err != hipSuccess) {
    return Error::RuntimeError() << hipGetErrorString(err);
  } else {
    return Maybe<void>::Ok();
  }
}
void CudaDevice::FreePinned(const AllocationOptions& options, void* ptr) {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipHostFree(ptr));
}
const hipDeviceProp_t& CudaDevice::properties() const { return properties_; }
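// GetConstZeros/GetConstOnes return the pre-allocated constant buffers when the request fits in
// the configured capacity and the data type has a matching buffer; otherwise they return nullptr
// and callers must handle that case themselves.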
const void* CudaDevice::GetConstZeros(DataType data_type, size_t n) const {
  if (GetSizeOfDataType(data_type) * n
      <= GetSizeOfDataType(DataType::kFloat) * const_buf_elem_cnt_) {
    return const_zeros_buffer_;
  } else {
    return nullptr;
  }
}
const void* CudaDevice::GetConstOnes(DataType data_type, size_t n) const {
  if (n <= const_buf_elem_cnt_) {
    if (data_type == DataType::kFloat) {
      return const_ones_buffer_fp32_;
    } else if (data_type == DataType::kFloat16) {
      return const_ones_buffer_fp16_;
    } else if (data_type == DataType::kBFloat16) {
      return const_ones_buffer_bf16_;
    } else {
      return nullptr;
    }
  } else {
    return nullptr;
  }
}
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
#include "oneflow/core/ep/include/device.h"
#include "oneflow/core/common/data_type.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
namespace oneflow {
namespace ep {
class CudaDevice : public Device {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDevice);
  explicit CudaDevice(int device_index, DeviceManager* device_manager);
  ~CudaDevice() override;
  void SetAsActiveDevice() override;
  DeviceType device_type() const override { return DeviceType::kCUDA; }
  size_t device_index() const override { return device_index_; }
  DeviceManager* device_manager() const override { return device_manager_; }
  Stream* CreateStream() override;
  void DestroyStream(Stream* stream) override;
  void CreateEvents(Event** events, size_t count) override;
  void DestroyEvents(Event** events, size_t count) override;
  Maybe<void> Alloc(const AllocationOptions& options, void** ptr, size_t size) override;
  void Free(const AllocationOptions& options, void* ptr) override;
  Maybe<void> AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override;
  void FreePinned(const AllocationOptions& options, void* ptr) override;
  const hipDeviceProp_t& properties() const;
  const void* GetConstZeros(DataType data_type, size_t n) const;
  const void* GetConstOnes(DataType data_type, size_t n) const;
 private:
  int device_index_;
  std::mutex events_mutex_;
  std::vector<Event*> events_;
  unsigned int event_flags_;
  hipDeviceProp_t properties_;
  DeviceManager* device_manager_;
  int64_t const_buf_elem_cnt_;
  void* const_zeros_buffer_;
  void* const_ones_buffer_fp32_;
  void* const_ones_buffer_fp16_;
  void* const_ones_buffer_bf16_;
};
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
#endif  // ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_device_manager.h"
#include "oneflow/core/device/cuda_util.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
CudaDeviceManager::CudaDeviceManager(DeviceManagerRegistry* registry) : registry_(registry) {}
CudaDeviceManager::~CudaDeviceManager() = default;
DeviceManagerRegistry* CudaDeviceManager::registry() const { return registry_; }
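// Devices are created lazily: GetDevice returns a cached CudaDevice when one already exists for
// the index, otherwise it constructs a new one under devices_mutex_ and grows the devices_
// vector as needed.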
std::shared_ptr<Device> CudaDeviceManager::GetDevice(size_t device_index) {
  std::lock_guard<std::mutex> lock(devices_mutex_);
  if (device_index < devices_.size() && devices_.at(device_index)) {
    return devices_.at(device_index);
  }
  auto device = std::make_shared<CudaDevice>(device_index, this);
  if (device_index >= devices_.size()) { devices_.resize(device_index + 1); }
  devices_.at(device_index) = device;
  return device;
}
size_t CudaDeviceManager::GetDeviceCount(size_t primary_device_index) {
  CudaCurrentDeviceGuard guard(primary_device_index);
  return this->GetDeviceCount();
}
size_t CudaDeviceManager::GetDeviceCount() {
  int count = 0;
  hipError_t err = hipGetDeviceCount(&count);
  if (err == hipErrorNoDevice || err == hipErrorInsufficientDriver) { return 0; }
  OF_CUDA_CHECK(err);
  return count;
}
size_t CudaDeviceManager::GetActiveDeviceIndex() {
  int device = 0;
  OF_CUDA_CHECK(hipGetDevice(&device));
  return static_cast<size_t>(device);
}
void CudaDeviceManager::SetActiveDeviceByIndex(size_t device_index) {
  OF_CUDA_CHECK(hipSetDevice(static_cast<int>(device_index)));
}
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
#include "oneflow/core/ep/include/device_manager.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
class CudaDevice;
class CudaDeviceManager : public DeviceManager {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManager);
  CudaDeviceManager(DeviceManagerRegistry* registry);
  ~CudaDeviceManager() override;
  DeviceManagerRegistry* registry() const override;
  std::shared_ptr<Device> GetDevice(size_t device_index) override;
  size_t GetDeviceCount(size_t primary_device_index) override;
  size_t GetDeviceCount() override;
  size_t GetActiveDeviceIndex() override;
  void SetActiveDeviceByIndex(size_t device_index) override;
 private:
  std::mutex devices_mutex_;
  std::vector<std::shared_ptr<CudaDevice>> devices_;
  DeviceManagerRegistry* registry_;
};
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
#endif  // ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/device_manager_factory.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
#include "oneflow/core/ep/rocm/cuda_device_manager.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
#include <rccl.h>
namespace oneflow {
namespace ep {
namespace {
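// Note: the helpers below keep the upstream CUDA naming and version decoding
// (major = version / 1000, minor = (version % 1000) / 10) even though the values are obtained
// from the HIP/MIOpen/RCCL APIs in this ROCm port.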
std::string GetCudaVersionString(int version) {
  return std::to_string(version / 1000) + "." + std::to_string((version % 1000) / 10);
}
bool GetCudnnVersion(size_t* major, size_t* minor, size_t* patch) {
  miopenStatus_t status = miopenGetVersion(major, minor, patch);
  if (status == miopenStatusSuccess) {
    return true;
  } else {
    LOG(ERROR) << "Failed to get cuDNN version: " << miopenGetErrorString(status);
    return false;
  }
}
bool GetCudnnVersionString(std::string* version) {
  size_t version_major = 0;
  size_t version_minor = 0;
  size_t version_patch = 0;
  if (!GetCudnnVersion(&version_major, &version_minor, &version_patch)) { return false; }
  *version = std::to_string(version_major) + "." + std::to_string(version_minor) + "."
             + std::to_string(version_patch);
  return true;
}
void CudaDumpVersionInfo() {
  {
    int cuda_runtime_version = 0;
    hipError_t err = hipRuntimeGetVersion(&cuda_runtime_version);
    if (err == hipSuccess) {
      LOG(INFO) << "CUDA runtime version: " << GetCudaVersionString(cuda_runtime_version);
    } else {
      LOG(ERROR) << "Failed to get cuda runtime version: " << hipGetErrorString(err);
    }
  }
  {
    std::string cudnn_version_string;
    if (GetCudnnVersionString(&cudnn_version_string)) {
      LOG(INFO) << "cuDNN version: " << cudnn_version_string;
    }
  }
  {
    int nccl_version = 0;
    ncclResult_t result = ncclGetVersion(&nccl_version);
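    // NCCL changed its version encoding in 2.9: values >= 20900 carry the major version in the
    // ten-thousands place, older values in the thousands place; the last two digits are the patch.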
    if (result == ncclSuccess) {
      int nccl_version_major =
          (nccl_version >= 20900) ? (nccl_version / 10000) : (nccl_version / 1000);
      int nccl_version_minor =
          (nccl_version >= 20900) ? (nccl_version % 10000) / 100 : (nccl_version % 1000) / 100;
      int nccl_version_patch = (nccl_version % 100);
      LOG(INFO) << "NCCL version: " << nccl_version_major << "." << nccl_version_minor << "."
                << nccl_version_patch;
    } else {
      LOG(ERROR) << "Failed to get NCCL version: " << ncclGetErrorString(result);
    }
  }
}
class CudaDeviceManagerFactory : public DeviceManagerFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManagerFactory);
  CudaDeviceManagerFactory() = default;
  ~CudaDeviceManagerFactory() override = default;
  std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {
    return std::make_unique<CudaDeviceManager>(registry);
  }
  DeviceType device_type() const override { return DeviceType::kCUDA; }
  std::string device_type_name() const override { return "cuda"; }
  void DumpVersionInfo() const override { CudaDumpVersionInfo(); }
};
COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(
    std::make_unique<CudaDeviceManagerFactory>()))
}  // namespace
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_event.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
CudaEvent::CudaEvent(unsigned int flags) : cuda_event_{} {
  OF_CUDA_CHECK(hipEventCreateWithFlags(&cuda_event_, flags));
}
CudaEvent::~CudaEvent() { OF_CUDA_CHECK(hipEventDestroy(cuda_event_)); }
Maybe<bool> CudaEvent::QueryDone() {
  hipError_t err = hipEventQuery(cuda_event_);
  if (err == hipSuccess) {
    return Maybe<bool>(true);
  } else if (err == hipErrorNotReady) {
    return Maybe<bool>(false);
  } else {
    return Error::RuntimeError() << hipGetErrorString(err);
  }
}
Maybe<void> CudaEvent::Sync() {
  hipError_t err = hipEventSynchronize(cuda_event_);
  if (err == hipSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << hipGetErrorString(err);
  }
}
hipEvent_t CudaEvent::cuda_event() { return cuda_event_; }
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
#include "oneflow/core/ep/include/event.h"
#ifdef WITH_ROCM
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace ep {
class CudaEvent : public Event {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaEvent);
  explicit CudaEvent(unsigned int flags);
  ~CudaEvent() override;
  Maybe<bool> QueryDone() override;
  Maybe<void> Sync() override;
  hipEvent_t cuda_event();
 private:
  hipEvent_t cuda_event_;
};
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
#endif  // ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/hardware/node_device_descriptor_manager.h"
#include "oneflow/core/hardware/cuda_device_descriptor.h"
#include "oneflow/core/ep/rocm/cuda_event.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
namespace oneflow {
namespace ep {
namespace {
constexpr size_t kDefaultWorkspaceSize = 4 * 1024 * 1024;  // 4M
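// Pins the calling thread's CPU and memory affinity to the NUMA domain closest to the GPU,
// looked up through the node device descriptor by the GPU's PCI bus ID; it is a no-op if the
// descriptor manager or the device descriptor is unavailable.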
void SetAffinityByDevice(int dev_id) {
  auto node_device_desc_mgr = Singleton<hardware::NodeDeviceDescriptorManager>::Get();
  if (node_device_desc_mgr == nullptr) { return; }
  auto node_device_desc = node_device_desc_mgr->GetLocalNodeDeviceDescriptor();
  auto cuda_device = std::dynamic_pointer_cast<const hardware::CudaDeviceDescriptor>(
      node_device_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev_id));
  if (!cuda_device) { return; }
  node_device_desc->Topology()->SetCPUAffinityByPCIBusID(cuda_device->PCIBusID());
  node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID());
}
}  // namespace
#ifdef WITH_ROCM_GRAPHS
CudaGraphExecutable::CudaGraphExecutable() : graph_exec_(nullptr), dev_(-1) {}
CudaGraphExecutable::~CudaGraphExecutable() { Reset(); }
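// Update first tries to patch the existing executable graph in place with hipGraphExecUpdate;
// if the device changed or the in-place update is rejected, it destroys the old executable and
// re-instantiates it from the captured graph.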
void CudaGraphExecutable::Update(hipGraph_t graph) {
  int dev = -1;
  OF_CUDA_CHECK(hipGetDevice(&dev));
  if (dev != dev_) { Reset(); }
  dev_ = dev;
  if (graph_exec_ != nullptr) {
    hipGraphExecUpdateResult update_result{};
    hipGraphNode_t error_node = nullptr;
    OF_CUDA_CHECK(hipGraphExecUpdate(graph_exec_, graph, &error_node, &update_result));
    if (update_result == hipGraphExecUpdateSuccess) { return; }
  }
  Reset();
  OF_CUDA_CHECK(hipGraphInstantiate(&graph_exec_, graph, NULL, NULL, 0));
}
void CudaGraphExecutable::Launch(hipStream_t stream) const {
  OF_CUDA_CHECK(hipGraphLaunch(graph_exec_, stream));
}
bool CudaGraphExecutable::IsInstantiated() const { return graph_exec_ != nullptr; }
void CudaGraphExecutable::Reset() {
  if (graph_exec_ != nullptr) {
    CudaCurrentDeviceGuard guard(dev_);
    OF_CUDA_CHECK(hipGraphExecDestroy(graph_exec_));
  }
}
#endif  // WITH_ROCM_GRAPHS
CudaStream::CudaStream(CudaDevice* device)
    : device_index_(device->device_index()), device_(device) {
  CudaCurrentDeviceGuard guard(device_index_);
  // cuda_stream
  OF_CUDA_CHECK(hipStreamCreate(&cuda_stream_));
  // cublas_handle
  OF_CUBLAS_CHECK(hipblasCreate(&cublas_handle_));
  OF_CUBLAS_CHECK(hipblasSetStream(cublas_handle_, cuda_stream_));
  workspace_size_ = kDefaultWorkspaceSize;
  OF_CUDA_CHECK(hipMalloc(&workspace_, workspace_size_));
  OF_CUDNN_CHECK(hipdnnCreate(&cudnn_handle_));
  OF_CUDNN_CHECK(hipdnnSetStream(cudnn_handle_, cuda_stream_));
}
CudaStream::~CudaStream() {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipStreamSynchronize(cuda_stream_));
  OF_CUDNN_CHECK(hipdnnDestroy(cudnn_handle_));
  OF_CUBLAS_CHECK(hipblasDestroy(cublas_handle_));
  OF_CUDA_CHECK(hipStreamDestroy(cuda_stream_));
  OF_CUDA_CHECK(hipFree(workspace_));
}
Maybe<void> CudaStream::OnExecutionContextSetup() {
  OF_CUDA_CHECK(hipSetDevice(device_index_));
  SetAffinityByDevice(device_index_);
  return Maybe<void>::Ok();
}
Maybe<void> CudaStream::OnExecutionContextTeardown() { return Maybe<void>::Ok(); }
DeviceType CudaStream::device_type() const { return DeviceType::kCUDA; }
CudaDevice* CudaStream::device() const { return device_; }
Maybe<void> CudaStream::Sync() {
  hipError_t err = hipStreamSynchronize(cuda_stream_);
  if (err == hipSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << hipGetErrorString(err) << " (" << err << ") ";
  }
}
void CudaStream::RecordEvent(Event* event) {
  auto* cuda_event = static_cast<CudaEvent*>(event);  // NOLINT
  OF_CUDA_CHECK(hipEventRecord(cuda_event->cuda_event(), cuda_stream_));
}
hipStream_t CudaStream::cuda_stream() const { return cuda_stream_; }
hipblasHandle_t CudaStream::cublas_handle() const { return cublas_handle_; }
void* CudaStream::cublas_workspace() const { return workspace_; }
size_t CudaStream::cublas_workspace_size() const { return workspace_size_; }
hipdnnHandle_t CudaStream::cudnn_handle() const { return cudnn_handle_; }
const hipDeviceProp_t& CudaStream::device_properties() const { return device_->properties(); }
int CudaStream::cuda_arch() const {
  return device_->properties().major * 100 + device_->properties().minor * 10;
}
#ifdef WITH_ROCM_GRAPHS
void CudaStream::BeginGraphCapture() {
  CHECK(!is_graph_capturing_);
  is_graph_capturing_ = true;
  OF_CUDA_CHECK(hipStreamBeginCapture(cuda_stream_, hipStreamCaptureModeThreadLocal));
}
void CudaStream::EndGraphCapture(CudaGraphExecutable* executable) {
  hipGraph_t graph = nullptr;
  OF_CUDA_CHECK(hipStreamEndCapture(cuda_stream_, &graph));
  executable->Update(graph);
  OF_CUDA_CHECK(hipGraphDestroy(graph));
  is_graph_capturing_ = false;
}
bool CudaStream::IsGraphCapturing() const { return is_graph_capturing_; }
void CudaStream::LaunchGraph(const CudaGraphExecutable* executable) {
  executable->Launch(cuda_stream_);
}
#endif  // WITH_ROCM_GRAPHS
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
#include "oneflow/core/ep/include/stream.h"
#include "oneflow/core/ep/rocm/cuda_device.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include "oneflow/core/hipdnn/hipdnn.h"
// #if CUDA_VERSION >= 11000
// #define WITH_ROCM_GRAPHS
// #endif  // CUDA_VERSION >= 11000
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace ep {
class CudaDevice;
#ifdef WITH_ROCM_GRAPHS
class CudaGraphExecutable {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaGraphExecutable);
  CudaGraphExecutable();
  ~CudaGraphExecutable();
  void Update(hipGraph_t graph);
  void Launch(hipStream_t stream) const;
  bool IsInstantiated() const;
 private:
  void Reset();
  hipGraphExec_t graph_exec_;
  int dev_;
};
#endif  // WITH_ROCM_GRAPHS
struct CudaLaunchConfig {
  dim3 grid_dim;
  dim3 block_dim;
  size_t shared_mem_size;
  CudaLaunchConfig() : grid_dim{}, block_dim{}, shared_mem_size(0) {}
  CudaLaunchConfig(unsigned int grid_size, unsigned int block_size, size_t shared_mem_size)
      : grid_dim(grid_size), block_dim(block_size), shared_mem_size(shared_mem_size) {}
};
class CudaStream : public Stream {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaStream);
  explicit CudaStream(CudaDevice* device);
  ~CudaStream() override;
  static constexpr uint32_t kDefaultBlockSize = 256;
  DeviceType device_type() const override;
  CudaDevice* device() const override;
  Maybe<void> Sync() override;
  void RecordEvent(Event* event) override;
  Maybe<void> OnExecutionContextSetup() override;
  Maybe<void> OnExecutionContextTeardown() override;
  hipStream_t cuda_stream() const;
  hipblasHandle_t cublas_handle() const;
  // #if CUDA_VERSION >= 10010
  // cublasLtHandle_t cublas_lt_handle() const;
  // #endif
  hipdnnHandle_t cudnn_handle() const;
  void* cublas_workspace() const;
  size_t cublas_workspace_size() const;
  const hipDeviceProp_t& device_properties() const;
  int cuda_arch() const;
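  // Computes a launch configuration whose grid size is capped at `max_waves` full waves of the
  // device (multiprocessor count times resident threads per multiprocessor divided by the block
  // size); the block size is taken as given and shared memory is set to zero.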
  void InitLaunchConfigWithWaves(CudaLaunchConfig* config, size_t elem_cnt, size_t block_size,
                                 size_t max_waves) const {
    const uint32_t max_grid_size = max_waves * device_properties().multiProcessorCount
                                   * (device_properties().maxThreadsPerMultiProcessor / block_size);
    const uint32_t grid_size =
        std::min<uint32_t>(max_grid_size, (elem_cnt + block_size - 1) / block_size);
    config->grid_dim = dim3(grid_size);
    config->block_dim = dim3(block_size);
    config->shared_mem_size = 0;
  }
#ifdef __HIPCC__
  template<typename... Params, typename... Args>
  void LaunchKernel(void (*kernel)(Params...), const CudaLaunchConfig& launch_config,
                    Args... args) {
    kernel<<<launch_config.grid_dim, launch_config.block_dim, launch_config.shared_mem_size,
             cuda_stream()>>>(args...);
  }
  template<typename... Params, typename... Args>
  void LaunchKernel(void (*kernel)(Params...), size_t elem_cnt, size_t max_waves, Args... args) {
    constexpr uint32_t block_size = kDefaultBlockSize;
    CudaLaunchConfig config{};
    InitLaunchConfigWithWaves(&config, elem_cnt, block_size, max_waves);
    LaunchKernel(kernel, config, args...);
  }
  template<typename... Params, typename... Args>
  void LaunchKernelDefaultWaves(void (*kernel)(Params...), size_t elem_cnt, Args... args) {
    const size_t default_waves = 32;
    LaunchKernel(kernel, elem_cnt, default_waves, args...);
  }
#endif  // __HIPCC__
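  // Illustrative sketch (not part of this header): a HIP kernel could be launched through this
  // stream roughly as follows; `ScaleKernel`, `in`, `out`, `alpha` and `n` are hypothetical names
  // chosen for the example.
  //
  //   __global__ void ScaleKernel(const float* in, float* out, float alpha, size_t n) {
  //     CUDA_1D_KERNEL_LOOP_T(size_t, i, n) { out[i] = alpha * in[i]; }
  //   }
  //   ...
  //   stream->As<ep::CudaStream>()->LaunchKernelDefaultWaves(ScaleKernel, n, in, out, alpha, n);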
#ifdef WITH_ROCM_GRAPHS
  void BeginGraphCapture();
  void EndGraphCapture(CudaGraphExecutable* executable);
  bool IsGraphCapturing() const;
  void LaunchGraph(const CudaGraphExecutable* executable);
#endif  // WITH_ROCM_GRAPHS
 private:
  hipStream_t cuda_stream_{};
  hipblasHandle_t cublas_handle_{};
  // #if CUDA_VERSION >= 10010
  // cublasLtHandle_t cublas_lt_handle_{};
  // #endif
  hipdnnHandle_t cudnn_handle_{};
  int device_index_;
  void* workspace_{};
  size_t workspace_size_{};
#ifdef WITH_ROCM_GRAPHS
  bool is_graph_capturing_{};
#endif  // WITH_ROCM_GRAPHS
  CudaDevice* device_;
};
}  // namespace ep
}  // namespace oneflow
#endif  // WITH_ROCM
#endif  // ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/add.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/device/cuda_pseudo_bfloat16.h"
namespace oneflow {
namespace ep {
namespace primitive {
namespace {
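// AddFunctor is a variadic functor that folds its arguments with operator+: the single-argument
// specialization returns its input, and the general case adds the first argument to the sum of
// the remaining ones.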
template<typename... Args>
struct AddFunctor;
template<typename T>
struct AddFunctor<T> {
  __device__ T operator()(T x) const { return x; }
};
template<typename T, typename U, typename... Args>
struct AddFunctor<T, U, Args...> {
  __device__ T operator()(T x0, U x1, Args... xs) const {
    return x0 + AddFunctor<U, Args...>()(x1, xs...);
  }
};
template<typename T, typename... Args>
__global__ void AddGpu(const Args*... srcs, T* dst, size_t count) {
  CUDA_1D_KERNEL_LOOP_T(size_t, i, count) { dst[i] = AddFunctor<Args...>()(srcs[i]...); }
}
template<typename T, typename... Args>
void LaunchAddGpu(hipStream_t stream, const Args*... srcs, T* dst, size_t count) {
  AddGpu<T, Args...>
      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(srcs..., dst, count);
}
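// DispatchLaunch picks a strategy by arity: 0 sources becomes a memset to zero, 1 a plain device
// copy, 2 and 3 go through the elementwise Binary/Ternary helpers, 4 to 8 use the variadic AddGpu
// kernel, and larger arities are handled recursively by summing the tail into dst first and then
// adding the remaining seven sources plus dst.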
template<typename T>
void DispatchLaunch(hipStream_t stream, const T* const* srcs, size_t arity, T* dst, size_t count) {
  if (arity == 0) {
    OF_CUDA_CHECK(hipMemsetAsync(dst, 0, count * sizeof(T), stream));
  } else if (arity == 1) {
    OF_CUDA_CHECK(hipMemcpyAsync(dst, srcs[0], count * sizeof(T), hipMemcpyDefault, stream));
  } else if (arity == 2) {
    OF_CUDA_CHECK((cuda::elementwise::Binary<AddFunctor<T, T>, T, T, T>(
        AddFunctor<T, T>(), count, dst, srcs[0], srcs[1], stream)));
  } else if (arity == 3) {
    OF_CUDA_CHECK((cuda::elementwise::Ternary<AddFunctor<T, T, T>, T, T, T, T>(
        AddFunctor<T, T, T>(), count, dst, srcs[0], srcs[1], srcs[2], stream)));
  } else if (arity == 4) {
    LaunchAddGpu<T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], dst, count);
  } else if (arity == 5) {
    LaunchAddGpu<T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], dst, count);
  } else if (arity == 6) {
    LaunchAddGpu<T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], srcs[5],
                                      dst, count);
  } else if (arity == 7) {
    LaunchAddGpu<T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
                                         srcs[5], srcs[6], dst, count);
  } else if (arity == 8) {
    LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
                                            srcs[5], srcs[6], srcs[7], dst, count);
  } else {
    DispatchLaunch(stream, srcs + 7, arity - 7, dst, count);
    LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
                                            srcs[5], srcs[6], dst, dst, count);
  }
}
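// Add primitive backed by the HIP stream: casts the type-erased pointers coming through
// Add::Launch to T and forwards them to DispatchLaunch.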
template<typename T>
class AddImpl : public Add {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AddImpl);
  AddImpl() = default;
  ~AddImpl() override = default;

  using Add::Launch;
  void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,
              size_t count) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    DispatchLaunch(cuda_stream, reinterpret_cast<const T* const*>(srcs), arity,
                   reinterpret_cast<T*>(dst), count);
  }
};

template<typename T>
std::unique_ptr<Add> NewAdd() {
  return std::unique_ptr<Add>(new AddImpl<T>());
}
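// Factory that maps a DataType to the matching NewAdd<T> constructor; the lookup table is
// generated from CUDA_PRIMITIVE_ALL_TYPE_SEQ, and unsupported types return nullptr.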
class AddFactoryImpl : public AddFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl);
  AddFactoryImpl() = default;
  ~AddFactoryImpl() override = default;

  std::unique_ptr<Add> New(DataType data_type) override {
#define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd<type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<Add>()>> new_add_handle{
        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_ADD_ENTRY
    const auto it = new_add_handle.find(data_type);
    if (it != new_add_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, AddFactory, AddFactoryImpl);
} // namespace
} // namespace primitive
} // namespace ep
} // namespace oneflow
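// Usage sketch (illustrative only, not part of this file): a caller would typically obtain this
// primitive through the registered factory rather than constructing AddImpl directly. The
// NewPrimitive<> lookup helper below is assumed to be OneFlow's standard factory accessor.
//
//   const void* srcs[] = {a, b};
//   std::unique_ptr<ep::primitive::Add> add =
//       ep::primitive::NewPrimitive<ep::primitive::AddFactory>(DeviceType::kCUDA, DataType::kFloat);
//   if (add) { add->Launch(stream, srcs, /*arity=*/2, dst, count); }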