Commit a715222c authored by yuguo's avatar yuguo
Browse files

0.9.1-rocm

parent f262efc9
......@@ -24,7 +24,470 @@ namespace embedding {
#ifdef WITH_CUDA
constexpr size_t kDefaultMaxQueryLength = 65536;
constexpr size_t kDefaultMaxQueryLength = 131072;
constexpr int64_t kRingBufferSize = 8;
// Per-iteration bookkeeping of unique-id statistics; instances live in a
// fixed-size ring buffer indexed by iter % kRingBufferSize.
struct IdStatistics {
  IdStatistics() = default;
  // Total number of unique ids for the iteration (0 until set).
  uint32_t final_num_unique = 0;
  // Pairwise unique-count matrix between ranks for the iteration.
  std::vector<uint32_t> num_unique_matrix;
  // Iteration this slot belongs to; -1 marks a slot that was never written.
  int64_t iter = -1;
};
#if CUDA_VERSION >= 11020
// TmpBufferAllocator backed by a stream-ordered CUDA memory pool
// (cudaMallocFromPoolAsync / cudaFreeAsync; requires CUDA >= 11.2).
class DynamicTmpBufferAllocator final : public TmpBufferAllocator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DynamicTmpBufferAllocator);
  DynamicTmpBufferAllocator(cudaStream_t stream, cudaMemPool_t pool)
      : stream_(stream), mem_pool_(pool) {}
  ~DynamicTmpBufferAllocator() override = default;

  // Allocates in stream order from the pool; the size is rounded up to the
  // CUDA alignment before the request is issued.
  void Allocate(void** ptr, size_t size) override {
    const size_t aligned_size = GetCudaAlignedSize(size);
    OF_CUDA_CHECK(cudaMallocFromPoolAsync(ptr, aligned_size, mem_pool_, stream_));
  }

  // Returns the buffer to the pool in stream order.
  void Free(void* ptr) override { OF_CUDA_CHECK(cudaFreeAsync(ptr, stream_)); }

 private:
  cudaStream_t stream_{};
  cudaMemPool_t mem_pool_{};
};
class DynamicAllocationEmbeddingState final : public EmbeddingState {
public:
OF_DISALLOW_COPY_AND_MOVE(DynamicAllocationEmbeddingState);
DynamicAllocationEmbeddingState()
: lookup_values_(nullptr),
lookup_values_size_(0),
has_lookup_values_(false),
lookup_embeddings_(nullptr),
lookup_embeddings_size_(0),
has_lookup_embeddings_(false),
updated_values_(nullptr),
iter_(-1) {
OF_CUDA_CHECK(cudaGetDevice(&device_index_));
id_statistics_vec_.resize(kRingBufferSize);
cudaMemPoolProps poolProps = {};
poolProps.allocType = cudaMemAllocationTypePinned;
poolProps.handleTypes = cudaMemHandleTypePosixFileDescriptor;
poolProps.location.type = cudaMemLocationTypeDevice;
poolProps.location.id = device_index_;
cudaMemPoolCreate(&mem_pool_, &poolProps);
uint64_t threshold = UINT64_MAX;
cudaMemPoolSetAttribute(mem_pool_, cudaMemPoolAttrReleaseThreshold, &threshold);
}
~DynamicAllocationEmbeddingState() {
CudaCurrentDeviceGuard guard(device_index_);
if (has_lookup_values_) { OF_CUDA_CHECK(cudaFree(lookup_values_)); }
if (has_lookup_embeddings_) { OF_CUDA_CHECK(cudaFree(lookup_embeddings_)); }
OF_CUDA_CHECK(cudaMemPoolDestroy(mem_pool_));
}
std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
user_op::KernelComputeContext* ctx) override {
return std::make_unique<DynamicTmpBufferAllocator>(
ctx->stream()->As<ep::CudaStream>()->cuda_stream(), mem_pool_);
}
void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
iter_ = iter;
cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();
user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex("unique_values", 0);
const int64_t embedding_size = ctx->Attr<int64_t>("embedding_size");
const int64_t line_size = ctx->Attr<int64_t>("line_size");
uint32_t num_unique = this->GetIdNumUnique(iter);
size_t lookup_values_size =
GetCudaAlignedSize(num_unique * line_size * GetSizeOfDataType(unique_values->data_type()));
if (!has_lookup_values_ || lookup_values_size_ < lookup_values_size) {
if (has_lookup_values_) { OF_CUDA_CHECK(cudaFreeAsync(lookup_values_, cuda_stream)); }
OF_CUDA_CHECK(
cudaMallocFromPoolAsync(&lookup_values_, lookup_values_size, mem_pool_, cuda_stream));
has_lookup_values_ = true;
lookup_values_size_ = lookup_values_size;
if (ctx->has_output("embeddings", 0)) {
user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0);
const size_t lookup_embeddings_size = GetCudaAlignedSize(
num_unique * embedding_size * GetSizeOfDataType(embeddings->data_type()));
if (!has_lookup_embeddings_ || lookup_embeddings_size_ < lookup_values_size) {
if (has_lookup_embeddings_) {
OF_CUDA_CHECK(cudaFreeAsync(lookup_embeddings_, cuda_stream));
}
OF_CUDA_CHECK(cudaMallocFromPoolAsync(&lookup_embeddings_, lookup_embeddings_size,
mem_pool_, cuda_stream));
has_lookup_embeddings_ = true;
lookup_embeddings_size_ = lookup_embeddings_size;
}
} else {
lookup_embeddings_ = nullptr;
}
}
}
void* LookupUniqueValues(int64_t iter) override {
CHECK_EQ(iter_, iter);
CHECK(has_lookup_values_);
return lookup_values_;
}
void* LookupEmbeddings(int64_t iter) override {
CHECK_EQ(iter_, iter);
CHECK(has_lookup_embeddings_);
return lookup_embeddings_;
}
void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
const void* EmbeddingGatherIn(int64_t iter) override {
if (has_lookup_embeddings_) {
return lookup_embeddings_;
} else {
CHECK(has_lookup_values_);
return lookup_values_;
}
}
void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override {
if (has_lookup_embeddings_) {
return lookup_embeddings_;
} else {
CHECK(has_lookup_values_);
return lookup_values_;
}
}
void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* updated_unique_embeddings =
ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0);
const int64_t line_size = ctx->Attr<int64_t>("line_size");
uint32_t num_unique = this->GetIdNumUnique(iter);
size_t update_values_size = GetCudaAlignedSize(
num_unique * line_size * GetSizeOfDataType(updated_unique_embeddings->data_type()));
OF_CUDA_CHECK(cudaMallocFromPoolAsync(&updated_values_, update_values_size, mem_pool_,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
}
const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) override {
CHECK_EQ(iter_, iter);
CHECK(has_lookup_values_);
return lookup_values_;
}
void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) override {
CHECK_EQ(iter_, iter);
return updated_values_;
}
void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
const void* EmbeddingPutUniqueEmbeddings(int64_t iter) override {
CHECK_EQ(iter_, iter);
return updated_values_;
}
void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
OF_CUDA_CHECK(
cudaFreeAsync(updated_values_, ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
}
void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) override {
CHECK_EQ(iter_, iter);
CHECK(has_lookup_values_);
return lookup_values_;
}
void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
// do nothing
}
void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
id_statistics_vec_.at(index).final_num_unique = final_num_unique;
id_statistics_vec_.at(index).iter = iter;
}
void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix, int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
id_statistics_vec_.at(index).num_unique_matrix = num_unique_matrix;
id_statistics_vec_.at(index).iter = iter;
}
uint32_t GetIdNumUnique(int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
const IdStatistics& statistics = id_statistics_vec_.at(index);
CHECK_EQ(statistics.iter, iter)
<< "saved iter: " << statistics.iter << " current iter: " << iter;
return statistics.final_num_unique;
}
const std::vector<uint32_t>& GetIdNumUniqueMatrix(int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
const IdStatistics& statistics = id_statistics_vec_.at(index);
CHECK_EQ(statistics.iter, iter)
<< "saved iter: " << statistics.iter << " current iter: " << iter;
return statistics.num_unique_matrix;
}
private:
void* lookup_values_;
size_t lookup_values_size_;
bool has_lookup_values_;
void* lookup_embeddings_;
size_t lookup_embeddings_size_;
bool has_lookup_embeddings_;
void* updated_values_;
int64_t iter_;
std::vector<IdStatistics> id_statistics_vec_;
int device_index_{};
cudaMemPool_t mem_pool_{};
std::mutex mutex_;
};
#endif
// Bump-pointer allocator over a caller-provided, statically planned buffer
// (the kernel's "tmp_buffer" tensor). Allocations are never reclaimed
// individually; the whole buffer is reused on the next kernel invocation.
class StaticTmpBufferAllocator final : public TmpBufferAllocator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(StaticTmpBufferAllocator);
  StaticTmpBufferAllocator(void* ptr, size_t size) : ptr_(ptr), offset_(0), size_(size) {}
  ~StaticTmpBufferAllocator() override = default;

  // Hands out the next CUDA-aligned slice; fails hard if the buffer overflows.
  void Allocate(void** ptr, size_t size) override {
    CHECK(ptr_ != nullptr);
    CHECK_GE(offset_, 0);
    const size_t aligned_size = GetCudaAlignedSize(size);
    CHECK_LE(offset_ + aligned_size, size_);
    char* const base = reinterpret_cast<char*>(ptr_);
    *ptr = base + offset_;
    offset_ += aligned_size;
  }

  // Intentionally a no-op: slices live until the backing tensor is reused.
  void Free(void* ptr) override {}

 private:
  void* ptr_;        // base of the backing buffer (not owned)
  int64_t offset_;   // next free byte offset
  size_t size_;      // total capacity in bytes
};
class StaticAllocationEmbeddingState final : public EmbeddingState {
public:
OF_DISALLOW_COPY_AND_MOVE(StaticAllocationEmbeddingState);
StaticAllocationEmbeddingState()
: lookup_unique_values_(nullptr),
lookup_embeddings_(nullptr),
has_lookup_embeddings_(false),
embedding_shuffle_cur_rank_embeddings_(nullptr),
embedding_update_unique_embeddings_(nullptr),
embedding_update_updated_unique_embeddings_(nullptr),
embedding_put_unique_embeddings_(nullptr),
embedding_fused_update_put_unique_embeddings_(nullptr) {
id_statistics_vec_.resize(kRingBufferSize);
}
~StaticAllocationEmbeddingState() override = default;
std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
user_op::KernelComputeContext* ctx) override {
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
return std::make_unique<StaticTmpBufferAllocator>(tmp_buffer->mut_dptr(),
tmp_buffer->shape_view().elem_cnt());
}
void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex("unique_values", 0);
lookup_unique_values_ = unique_values->mut_dptr();
if (ctx->has_output("embeddings", 0)) {
user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0);
has_lookup_embeddings_ = true;
lookup_embeddings_ = embeddings->mut_dptr();
}
}
void* LookupUniqueValues(int64_t iter) override { return lookup_unique_values_; }
void* LookupEmbeddings(int64_t iter) override {
CHECK(has_lookup_embeddings_);
return lookup_embeddings_;
}
void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
lookup_unique_values_ = nullptr;
lookup_embeddings_ = nullptr;
has_lookup_embeddings_ = false;
}
void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
embedding_gather_in_ = in->dptr();
}
const void* EmbeddingGatherIn(int64_t iter) override { return embedding_gather_in_; }
void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_gather_in_ = nullptr;
}
void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* cur_rank_embeddings =
ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0);
embedding_shuffle_cur_rank_embeddings_ = cur_rank_embeddings->dptr();
}
const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override {
return embedding_shuffle_cur_rank_embeddings_;
}
void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_shuffle_cur_rank_embeddings_ = nullptr;
}
void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0);
user_op::Tensor* updated_unique_embeddings =
ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0);
embedding_update_unique_embeddings_ = unique_embeddings->dptr();
embedding_update_updated_unique_embeddings_ = updated_unique_embeddings->mut_dptr();
}
const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) override {
return embedding_update_unique_embeddings_;
}
void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) override {
return embedding_update_updated_unique_embeddings_;
}
void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_update_unique_embeddings_ = nullptr;
embedding_update_updated_unique_embeddings_ = nullptr;
}
void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0);
embedding_put_unique_embeddings_ = unique_embeddings->dptr();
}
const void* EmbeddingPutUniqueEmbeddings(int64_t iter) override {
return embedding_put_unique_embeddings_;
}
void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_put_unique_embeddings_ = nullptr;
}
void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0);
embedding_fused_update_put_unique_embeddings_ = unique_embeddings->dptr();
}
const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) override {
return embedding_fused_update_put_unique_embeddings_;
}
void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_fused_update_put_unique_embeddings_ = nullptr;
}
void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
id_statistics_vec_.at(index).final_num_unique = final_num_unique;
id_statistics_vec_.at(index).iter = iter;
}
void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix, int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
id_statistics_vec_.at(index).num_unique_matrix = num_unique_matrix;
id_statistics_vec_.at(index).iter = iter;
}
uint32_t GetIdNumUnique(int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
const IdStatistics& statistics = id_statistics_vec_.at(index);
CHECK_EQ(statistics.iter, iter)
<< "saved iter: " << statistics.iter << " current iter: " << iter;
return statistics.final_num_unique;
}
const std::vector<uint32_t>& GetIdNumUniqueMatrix(int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
const IdStatistics& statistics = id_statistics_vec_.at(index);
CHECK_EQ(statistics.iter, iter)
<< "saved iter: " << statistics.iter << " current iter: " << iter;
return statistics.num_unique_matrix;
}
void* lookup_unique_values_;
void* lookup_embeddings_;
bool has_lookup_embeddings_;
const void* embedding_gather_in_;
const void* embedding_shuffle_cur_rank_embeddings_;
const void* embedding_update_unique_embeddings_;
void* embedding_update_updated_unique_embeddings_;
const void* embedding_put_unique_embeddings_;
const void* embedding_fused_update_put_unique_embeddings_;
std::vector<IdStatistics> id_statistics_vec_;
std::mutex mutex_;
};
// Returns the EmbeddingState for (embedding_name, rank_id), creating it on
// first use. Dynamic-allocation state requires CUDA >= 11.2.
EmbeddingState* EmbeddingManager::GetEmbeddingState(const std::string& embedding_name,
                                                    int64_t rank_id) {
  const auto map_key = std::make_pair(embedding_name, rank_id);
  std::unique_lock<std::mutex> lock(mutex_);
  auto it = embedding_state_map_.find(map_key);
  if (it != embedding_state_map_.end()) { return it->second.get(); }
  // Created lazily: id-shuffle tests reach here without creating a table first.
  LOG(INFO) << "create embedding state: " << embedding_name << "-" << rank_id;
  std::unique_ptr<EmbeddingState> state;
  if (UseDynamicMemoryAllocation()) {
#if CUDA_VERSION >= 11020
    state = std::make_unique<DynamicAllocationEmbeddingState>();
#else
    UNIMPLEMENTED();
#endif
  } else {
    state = std::make_unique<StaticAllocationEmbeddingState>();
  }
  return embedding_state_map_.emplace(map_key, std::move(state)).first->second.get();
}
KeyValueStore* EmbeddingManager::GetKeyValueStore(const std::string& embedding_name,
int64_t rank_id) {
......@@ -66,6 +529,22 @@ void EmbeddingManager::CreateKeyValueStore(const KeyValueStoreOptions& key_value
store->ReserveQueryLength(kDefaultMaxQueryLength);
CHECK(key_value_store_map_.emplace(map_key, std::move(store)).second)
<< "Can't create an embedding with same name of an existing embedding, the name: " << name;
if (UseDynamicMemoryAllocation()) {
#if CUDA_VERSION >= 11020
CHECK(embedding_state_map_.emplace(map_key, std::make_unique<DynamicAllocationEmbeddingState>())
.second)
<< "Can't create an embedding state with same name of an existing embedding, the name: "
<< name;
#else
UNIMPLEMENTED();
#endif
} else {
CHECK(embedding_state_map_.emplace(map_key, std::make_unique<StaticAllocationEmbeddingState>())
.second)
<< "Can't create an embedding state with same name of an existing embedding, the name: "
<< name;
}
}
void EmbeddingManager::SaveSnapshot(const std::string& embedding_name, int64_t local_rank_id,
......@@ -101,6 +580,221 @@ void EmbeddingManager::LoadSnapshot(const std::string& embedding_name, int64_t l
constexpr size_t kDefaultMaxQueryLength = 131072;
constexpr int64_t kRingBufferSize = 8;
// Per-iteration unique-id statistics; stored in a ring buffer keyed by
// iter % kRingBufferSize.
struct IdStatistics {
  IdStatistics() = default;
  // Total number of unique ids for the iteration (0 until set).
  uint32_t final_num_unique = 0;
  // Pairwise unique-count matrix between ranks for the iteration.
  std::vector<uint32_t> num_unique_matrix;
  // Iteration this slot belongs to; -1 marks an unwritten slot.
  int64_t iter = -1;
};
// Bump-pointer allocator over the kernel's statically planned "tmp_buffer"
// tensor. Individual frees are no-ops; the buffer is reused wholesale.
class StaticTmpBufferAllocator final : public TmpBufferAllocator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(StaticTmpBufferAllocator);
  StaticTmpBufferAllocator(void* ptr, size_t size) : ptr_(ptr), offset_(0), size_(size) {}
  ~StaticTmpBufferAllocator() override = default;

  // Hands out the next aligned slice; fails hard if capacity is exceeded.
  void Allocate(void** ptr, size_t size) override {
    CHECK(ptr_ != nullptr);
    CHECK_GE(offset_, 0);
    const size_t aligned_size = GetCudaAlignedSize(size);
    CHECK_LE(offset_ + aligned_size, size_);
    char* const base = reinterpret_cast<char*>(ptr_);
    *ptr = base + offset_;
    offset_ += aligned_size;
  }

  // Intentionally a no-op: slices live until the backing tensor is reused.
  void Free(void* ptr) override {}

 private:
  void* ptr_;        // base of the backing buffer (not owned)
  int64_t offset_;   // next free byte offset
  size_t size_;      // total capacity in bytes
};
class StaticAllocationEmbeddingState final : public EmbeddingState {
public:
OF_DISALLOW_COPY_AND_MOVE(StaticAllocationEmbeddingState);
StaticAllocationEmbeddingState()
: lookup_unique_values_(nullptr),
lookup_embeddings_(nullptr),
has_lookup_embeddings_(false),
embedding_shuffle_cur_rank_embeddings_(nullptr),
embedding_update_unique_embeddings_(nullptr),
embedding_update_updated_unique_embeddings_(nullptr),
embedding_put_unique_embeddings_(nullptr),
embedding_fused_update_put_unique_embeddings_(nullptr) {
id_statistics_vec_.resize(kRingBufferSize);
}
~StaticAllocationEmbeddingState() override = default;
std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
user_op::KernelComputeContext* ctx) override {
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
return std::make_unique<StaticTmpBufferAllocator>(tmp_buffer->mut_dptr(),
tmp_buffer->shape_view().elem_cnt());
}
void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
user_op::Tensor* unique_values = ctx->Tensor4ArgNameAndIndex("unique_values", 0);
lookup_unique_values_ = unique_values->mut_dptr();
if (ctx->has_output("embeddings", 0)) {
user_op::Tensor* embeddings = ctx->Tensor4ArgNameAndIndex("embeddings", 0);
has_lookup_embeddings_ = true;
lookup_embeddings_ = embeddings->mut_dptr();
}
}
void* LookupUniqueValues(int64_t iter) override { return lookup_unique_values_; }
void* LookupEmbeddings(int64_t iter) override {
CHECK(has_lookup_embeddings_);
return lookup_embeddings_;
}
void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
lookup_unique_values_ = nullptr;
lookup_embeddings_ = nullptr;
has_lookup_embeddings_ = false;
}
void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
embedding_gather_in_ = in->dptr();
}
const void* EmbeddingGatherIn(int64_t iter) override { return embedding_gather_in_; }
void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_gather_in_ = nullptr;
}
void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* cur_rank_embeddings =
ctx->Tensor4ArgNameAndIndex("cur_rank_embeddings", 0);
embedding_shuffle_cur_rank_embeddings_ = cur_rank_embeddings->dptr();
}
const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) override {
return embedding_shuffle_cur_rank_embeddings_;
}
void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_shuffle_cur_rank_embeddings_ = nullptr;
}
void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0);
user_op::Tensor* updated_unique_embeddings =
ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0);
embedding_update_unique_embeddings_ = unique_embeddings->dptr();
embedding_update_updated_unique_embeddings_ = updated_unique_embeddings->mut_dptr();
}
const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) override {
return embedding_update_unique_embeddings_;
}
void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) override {
return embedding_update_updated_unique_embeddings_;
}
void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_update_unique_embeddings_ = nullptr;
embedding_update_updated_unique_embeddings_ = nullptr;
}
void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0);
embedding_put_unique_embeddings_ = unique_embeddings->dptr();
}
const void* EmbeddingPutUniqueEmbeddings(int64_t iter) override {
return embedding_put_unique_embeddings_;
}
void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_put_unique_embeddings_ = nullptr;
}
void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) override {
const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0);
embedding_fused_update_put_unique_embeddings_ = unique_embeddings->dptr();
}
const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) override {
return embedding_fused_update_put_unique_embeddings_;
}
void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) override {
embedding_fused_update_put_unique_embeddings_ = nullptr;
}
void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
id_statistics_vec_.at(index).final_num_unique = final_num_unique;
id_statistics_vec_.at(index).iter = iter;
}
void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix, int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
id_statistics_vec_.at(index).num_unique_matrix = num_unique_matrix;
id_statistics_vec_.at(index).iter = iter;
}
uint32_t GetIdNumUnique(int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
const IdStatistics& statistics = id_statistics_vec_.at(index);
CHECK_EQ(statistics.iter, iter)
<< "saved iter: " << statistics.iter << " current iter: " << iter;
return statistics.final_num_unique;
}
const std::vector<uint32_t>& GetIdNumUniqueMatrix(int64_t iter) override {
std::unique_lock<std::mutex> lock(mutex_);
int64_t index = iter % kRingBufferSize;
const IdStatistics& statistics = id_statistics_vec_.at(index);
CHECK_EQ(statistics.iter, iter)
<< "saved iter: " << statistics.iter << " current iter: " << iter;
return statistics.num_unique_matrix;
}
void* lookup_unique_values_;
void* lookup_embeddings_;
bool has_lookup_embeddings_;
const void* embedding_gather_in_;
const void* embedding_shuffle_cur_rank_embeddings_;
const void* embedding_update_unique_embeddings_;
void* embedding_update_updated_unique_embeddings_;
const void* embedding_put_unique_embeddings_;
const void* embedding_fused_update_put_unique_embeddings_;
std::vector<IdStatistics> id_statistics_vec_;
std::mutex mutex_;
};
// Returns the EmbeddingState for (embedding_name, rank_id), creating it on
// first use. Dynamic memory allocation is unsupported on this (ROCm) build.
EmbeddingState* EmbeddingManager::GetEmbeddingState(const std::string& embedding_name,
                                                    int64_t rank_id) {
  const auto map_key = std::make_pair(embedding_name, rank_id);
  std::unique_lock<std::mutex> lock(mutex_);
  auto it = embedding_state_map_.find(map_key);
  if (it == embedding_state_map_.end()) {
    // Created lazily: id-shuffle tests reach here without creating a table first.
    LOG(INFO) << "create embedding state: " << embedding_name << "-" << rank_id;
    if (UseDynamicMemoryAllocation()) {
      UNIMPLEMENTED();  // dynamic allocation requires the CUDA (>= 11.2) build
    } else {
      it = embedding_state_map_.emplace(map_key, std::make_unique<StaticAllocationEmbeddingState>())
               .first;
    }
  }
  return it->second.get();
}
KeyValueStore* EmbeddingManager::GetKeyValueStore(const std::string& embedding_name,
int64_t rank_id) {
std::pair<std::string, int64_t> map_key = std::make_pair(embedding_name, rank_id);
......@@ -141,6 +835,15 @@ void EmbeddingManager::CreateKeyValueStore(const KeyValueStoreOptions& key_value
store->ReserveQueryLength(kDefaultMaxQueryLength);
CHECK(key_value_store_map_.emplace(map_key, std::move(store)).second)
<< "Can't create an embedding with same name of an existing embedding, the name: " << name;
if (UseDynamicMemoryAllocation()) {
UNIMPLEMENTED();
} else {
CHECK(embedding_state_map_.emplace(map_key, std::make_unique<StaticAllocationEmbeddingState>())
.second)
<< "Can't create an embedding state with same name of an existing embedding, the name: "
<< name;
}
}
void EmbeddingManager::SaveSnapshot(const std::string& embedding_name, int64_t local_rank_id,
......@@ -170,7 +873,7 @@ void EmbeddingManager::LoadSnapshot(const std::string& embedding_name, int64_t l
}
}
#endif // WITH_ROCM
#endif
} // namespace embedding
......
......@@ -20,36 +20,149 @@ limitations under the License.
#include "oneflow/core/embedding/key_value_store.h"
#include "oneflow/core/embedding/key_value_store_options.h"
#include "oneflow/core/framework/framework.h"
namespace oneflow {
namespace embedding {
#ifdef WITH_CUDA
class EmbeddingManager final {
// Whether one-embedding kernels should allocate intermediate buffers
// dynamically, controlled by ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION.
// Requires CUDA >= 11.2; on older toolkits the flag is ignored with a warning.
inline bool UseDynamicMemoryAllocation() {
  // Env var is read once; the result is cached for the process lifetime.
  static const bool enabled =
      ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION", false);
#if CUDA_VERSION >= 11020
  return enabled;
#else
  if (enabled) {
    LOG(WARNING)
        << "Dynamic memory allocation only support when cuda_version greater equal than 11.2. ";
  }
  return false;
#endif
}
// Whether the P2P embedding-shuffle kernel may be used for the given dtypes.
// Requires the env flag, fp16 embeddings + uint32 indices, id-shuffle copy-out
// (stable buffer pointers), no quantized comm, no dynamic allocation, and
// CUDA >= 11.3.
inline bool UseEmbeddingShuffleP2PKernel(DataType embedding_dtype, DataType idx_dtype) {
  static bool use_embedding_shuffle_p2p_env =
      ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_USE_P2P", false);
  static bool add_id_shuffle_copy_out_env =
      ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT", true);
  static bool enable_quantized_comm =
      ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false);
  if (use_embedding_shuffle_p2p_env) {
    // The p2p kernel is only registered for (kFloat16, kUInt32).
    const bool dtype_supported =
        (embedding_dtype == DataType::kFloat16 && idx_dtype == DataType::kUInt32);
    // Without id-shuffle copy-out the buffer pointers change every iteration;
    // quantized comm and dynamic allocation are unsupported by the p2p kernel.
    if (!dtype_supported || !add_id_shuffle_copy_out_env || enable_quantized_comm
        || UseDynamicMemoryAllocation()) {
      return false;
    }
  }
#if CUDA_VERSION >= 11030
  return use_embedding_shuffle_p2p_env;
#else
  if (use_embedding_shuffle_p2p_env) {
    LOG(WARNING)
        << "embedding shuffle p2p kernel only support when cuda_version greater equal than 11.3. ";
  }
  return false;
#endif
}
// Whether the P2P embedding-gradient-shuffle kernel may be used. Same
// preconditions as the forward p2p shuffle: env flag, (kFloat16, kUInt32)
// dtypes, id-shuffle copy-out enabled, no quantized comm, no dynamic
// allocation, and CUDA >= 11.3.
inline bool UseEmbeddingGradientShuffleP2PKernel(DataType embedding_dtype, DataType idx_dtype) {
  static bool use_embedding_gradient_shuffle_p2p_env =
      ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_EMBEDDING_GRADIENT_SHUFFLE_USE_P2P", false);
  static bool add_id_shuffle_copy_out_env =
      ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT", true);
  static bool enable_quantized_comm =
      ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM", false);
  if (use_embedding_gradient_shuffle_p2p_env) {
    // The p2p kernel is only registered for (kFloat16, kUInt32).
    const bool dtype_supported =
        (embedding_dtype == DataType::kFloat16 && idx_dtype == DataType::kUInt32);
    // Without id-shuffle copy-out the buffer pointers change every iteration;
    // quantized comm and dynamic allocation are unsupported by the p2p kernel.
    if (!dtype_supported || !add_id_shuffle_copy_out_env || enable_quantized_comm
        || UseDynamicMemoryAllocation()) {
      return false;
    }
  }
#if CUDA_VERSION >= 11030
  return use_embedding_gradient_shuffle_p2p_env;
#else
  if (use_embedding_gradient_shuffle_p2p_env) {
    LOG(WARNING) << "embedding gradient shuffle p2p kernel only support when cuda_version greater "
                    "equal than 11.3. ";
  }
  return false;
#endif
}
#if defined(WITH_CUDA) || defined(WITH_ROCM)
class TmpBufferAllocator {
public:
EmbeddingManager() = default;
~EmbeddingManager() = default;
void SaveSnapshot(const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id,
const std::string& snapshot_name);
void LoadSnapshot(const std::string& embedding_name, int64_t local_rank_id, int64_t rank_id,
const std::string& snapshot_name);
TmpBufferAllocator() = default;
virtual ~TmpBufferAllocator() = default;
KeyValueStore* GetKeyValueStore(const std::string& embedding_name, int64_t rank_id);
void CreateKeyValueStore(const KeyValueStoreOptions& options, int64_t local_rank_id,
int64_t rank_id, int64_t world_size);
private:
HashMap<std::pair<std::string, int64_t>, std::unique_ptr<KeyValueStore>> key_value_store_map_;
std::mutex mutex_;
virtual void Allocate(void** ptr, size_t size) = 0;
virtual void Free(void* ptr) = 0;
};
#endif // WITH_CUDA
#ifdef WITH_ROCM
// Abstract per-embedding state shared by the one-embedding kernels.
// Each kernel phase (lookup / gather / shuffle / update / put) is bracketed by
// On...Start / On...End hooks; the accessors in between return the buffers the
// kernel should read or write for the given iteration. Concrete subclasses
// decide where those buffers come from (static tmp_buffer tensors vs.
// dynamically allocated pool memory).
class EmbeddingState {
public:
EmbeddingState() = default;
virtual ~EmbeddingState() = default;
// Creates the allocator used for kernel-temporary buffers.
virtual std::unique_ptr<TmpBufferAllocator> NewTmpBufferAllocator(
user_op::KernelComputeContext* ctx) = 0;
// Lookup phase: prepare/release the unique-values and optional embeddings buffers.
virtual void OnEmbeddingLookupStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual void* LookupUniqueValues(int64_t iter) = 0;
virtual void* LookupEmbeddings(int64_t iter) = 0;
virtual void OnEmbeddingLookupEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
// Gather phase: input buffer the gather kernel reads from.
virtual void OnEmbeddingGatherStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual const void* EmbeddingGatherIn(int64_t iter) = 0;
virtual void OnEmbeddingGatherEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
// Shuffle phase: current rank's embeddings consumed by the shuffle kernel.
virtual void OnEmbeddingShuffleStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual const void* EmbeddingShuffleCurRankEmbeddings(int64_t iter) = 0;
virtual void OnEmbeddingShuffleEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
// Update phase: source rows and the destination for optimizer-updated rows.
virtual void OnEmbeddingUpdateStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual const void* EmbeddingUpdateUniqueEmbeddings(int64_t iter) = 0;
virtual void* EmbeddingUpdateUpdatedUniqueEmbeddings(int64_t iter) = 0;
virtual void OnEmbeddingUpdateEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
// Put phase: rows written back to the key-value store.
virtual void OnEmbeddingPutStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual const void* EmbeddingPutUniqueEmbeddings(int64_t iter) = 0;
virtual void OnEmbeddingPutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
// Fused update+put phase: single pass over the unique embeddings.
virtual void OnEmbeddingFusedUpdatePutStart(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
virtual const void* EmbeddingFusedUpdatePutUniqueEmbeddings(int64_t iter) = 0;
virtual void OnEmbeddingFusedUpdatePutEnd(user_op::KernelComputeContext* ctx, int64_t iter) = 0;
// Per-iteration unique-id statistics, stored in a ring buffer keyed by iter.
virtual void SetIdFinalNumUnique(uint32_t final_num_unique, int64_t iter) = 0;
virtual void SetIdNumUniqueMatrix(const std::vector<uint32_t>& num_unique_matrix,
int64_t iter) = 0;
virtual uint32_t GetIdNumUnique(int64_t iter) = 0;
virtual const std::vector<uint32_t>& GetIdNumUniqueMatrix(int64_t iter) = 0;
};
class EmbeddingManager final {
public:
......@@ -62,16 +175,17 @@ class EmbeddingManager final {
const std::string& snapshot_name);
KeyValueStore* GetKeyValueStore(const std::string& embedding_name, int64_t rank_id);
EmbeddingState* GetEmbeddingState(const std::string& embedding_name, int64_t rank_id);
void CreateKeyValueStore(const KeyValueStoreOptions& options, int64_t local_rank_id,
int64_t rank_id, int64_t world_size);
private:
HashMap<std::pair<std::string, int64_t>, std::unique_ptr<KeyValueStore>> key_value_store_map_;
HashMap<std::pair<std::string, int64_t>, std::unique_ptr<EmbeddingState>> embedding_state_map_;
std::mutex mutex_;
};
#endif // WITH_ROCM
#endif // WITH_CUDA
} // namespace embedding
} // namespace oneflow
......
......@@ -28,9 +28,9 @@ using Key128 = ulonglong2;
namespace {
template<typename Key, typename Index>
__device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Index* table_size,
Key key, Index* out) {
template<typename Key, typename Index, bool dump_dirty_only>
__device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, bool* entry_dirty_flag,
Index* table_size, Key key, Index* out) {
Key key_hi = (key | 0x1);
Key key_lo = (key & 0x1);
Index index_plus_one = 0;
......@@ -41,6 +41,10 @@ __device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Inde
index_plus_one = index + 1;
*entry_index = ((index_plus_one << 1U) | key_lo);
*out = index_plus_one;
if (dump_dirty_only) {
bool entry_flag_val = *entry_dirty_flag;
if (!entry_flag_val) { *entry_dirty_flag = true; }
}
return true;
} else if (old_entry_key == key_hi) {
const Index entry_index_val = *entry_index;
......@@ -48,6 +52,10 @@ __device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Inde
// do nothing
} else if ((entry_index_val & 0x1) == key_lo) {
*out = (entry_index_val >> 1U);
if (dump_dirty_only) {
bool entry_flag_val = *entry_dirty_flag;
if (!entry_flag_val) { *entry_dirty_flag = true; }
}
return true;
} else {
return false;
......@@ -59,15 +67,20 @@ __device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Inde
return false;
}
template<typename Key, typename Index>
template<typename Key, typename Index, bool dump_dirty_only>
__device__ bool GetOrInsertOne(const size_t capacity, Key* table_keys, Index* table_indices,
Index* table_size, Key key, size_t hash, Index* out) {
bool* table_dirty_flags, Index* table_size, Key key, size_t hash,
Index* out) {
const size_t start_idx = hash % capacity;
for (size_t count = 0; count < capacity; ++count) {
const size_t idx = (start_idx + count) % capacity;
Key* entry_key = table_keys + idx;
Index* entry_index = table_indices + idx;
if (TryGetOrInsert<Key, Index>(entry_key, entry_index, table_size, key, out)) { return true; }
bool* entry_dirty_flag = dump_dirty_only ? table_dirty_flags + idx : nullptr;
if (TryGetOrInsert<Key, Index, dump_dirty_only>(entry_key, entry_index, entry_dirty_flag,
table_size, key, out)) {
return true;
}
}
return false;
}
......@@ -94,15 +107,15 @@ __device__ bool GetOne(const size_t capacity, Key* table_keys, Index* table_indi
return false;
}
template<typename Key, typename Index>
template<typename Key, typename Index, bool dump_dirty_only>
__global__ void OrdinalEncodeKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
Index* table_size, uint32_t num_keys, const Key* keys,
Index* context) {
bool* table_dirty_flags, Index* table_size, uint32_t num_keys,
const Key* keys, Index* context) {
CUDA_1D_KERNEL_LOOP(i, num_keys) {
Key key = keys[i];
uint64_t hash = FullCacheHash()(key);
bool success = GetOrInsertOne<Key, Index>(capacity, table_keys, table_indices, table_size, key,
hash, context + i);
bool success = GetOrInsertOne<Key, Index, dump_dirty_only>(
capacity, table_keys, table_indices, table_dirty_flags, table_size, key, hash, context + i);
assert(success);
}
}
......@@ -117,14 +130,20 @@ __global__ void OrdinalEncodeLookupKernel(uint64_t capacity, Key* table_keys, In
}
}
template<typename Key, typename Index>
template<typename Key, typename Index, bool dump_dirty_only>
__global__ void OrdinalEncodeDumpKernel(const Key* table_keys, const Index* table_indices,
uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, Key* keys, Index* context) {
const bool* table_dirty_flags, uint64_t start_key_index,
uint64_t end_key_index, uint32_t* n_dumped, Key* keys,
Index* context) {
CUDA_1D_KERNEL_LOOP(i, (end_key_index - start_key_index)) {
Key entry_key = table_keys[i + start_key_index];
Index entry_index = table_indices[i + start_key_index];
if (entry_index != 0) {
bool dump_flag = (entry_index != 0);
if (dump_dirty_only) {
bool entry_dirty_flag = table_dirty_flags[i + start_key_index];
dump_flag = (dump_flag && entry_dirty_flag);
}
if (dump_flag) {
uint32_t index = cuda::atomic::Add(n_dumped, static_cast<uint32_t>(1));
keys[index] = ((entry_key ^ 0x1) | (entry_index & 0x1));
context[index] = (entry_index >> 1U);
......@@ -177,7 +196,11 @@ __global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_valu
batch_start += global_n_warp * warp_size) {
const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
if (lane_id == 0) { batch_n_missing[warp_id] = 0; }
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
const uint32_t key_offset = batch_start + lane_id;
if (key_offset < n_keys) {
const Key key = keys[batch_start + lane_id];
......@@ -191,14 +214,22 @@ __global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_valu
batch_missing_indices[warp_id][batch_missing_idx] = key_offset;
}
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
const uint32_t batch_n_missing_t = batch_n_missing[warp_id];
if (lane_id == 0) {
const uint32_t old_n_missing =
cuda::atomic::Add(n_missing, static_cast<uint32_t>(batch_n_missing_t));
batch_n_missing[warp_id] = old_n_missing;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
if (lane_id < batch_n_missing_t) {
missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id];
missing_indices[batch_n_missing[warp_id] + lane_id] = batch_missing_indices[warp_id][lane_id];
......@@ -212,7 +243,11 @@ __global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_valu
cache_values[(row - 1) * value_length + col];
}
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
}
......@@ -252,7 +287,11 @@ __global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __rest
batch_row_ids[warp_id][lane_id] = row;
mask[key_offset] = row > 0;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
for (int i = 0; i < batch_n_key; ++i) {
const Key key = batch_keys[warp_id][i];
const Index row = batch_row_ids[warp_id][i];
......@@ -263,7 +302,11 @@ __global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __rest
packed_cache_values[(row - 1) * packed_cols + col];
}
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
}
......@@ -314,7 +357,7 @@ __global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::typ
FusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
const Index* context, const Elem* values, const half* update, const float* lr,
float scale) {
__trap();
TRAP();
}
template<typename Key, typename Elem, typename Index>
......@@ -333,33 +376,39 @@ template<typename Key, typename Index>
class OrdinalEncoder {
public:
OF_DISALLOW_COPY_AND_MOVE(OrdinalEncoder);
explicit OrdinalEncoder(uint64_t capacity, float load_factor)
: capacity_(capacity), table_capacity_(capacity / load_factor) {
OF_CUDA_CHECK(cudaGetDevice(&device_index_));
OF_CUDA_CHECK(cudaMalloc(&table_size_, sizeof(Index)));
explicit OrdinalEncoder(uint64_t capacity, float load_factor, bool if_dump_dirty)
: capacity_(capacity),
table_capacity_(capacity / load_factor),
if_dump_dirty_(if_dump_dirty) {
OF_CUDA_CHECK(GPU(GetDevice)(&device_index_));
OF_CUDA_CHECK(GPU(Malloc)(&table_size_, sizeof(Index)));
#ifdef WITH_ROCM
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&table_size_host_), sizeof(Index)));
#else
OF_CUDA_CHECK(cudaMallocHost(&table_size_host_, sizeof(Index)));
OF_CUDA_CHECK(cudaMalloc(&table_keys_, table_capacity_ * sizeof(Key)));
OF_CUDA_CHECK(cudaMalloc(&table_indices_, table_capacity_ * sizeof(Index)));
#endif
OF_CUDA_CHECK(GPU(Malloc)(&table_keys_, table_capacity_ * sizeof(Key)));
OF_CUDA_CHECK(GPU(Malloc)(&table_indices_, table_capacity_ * sizeof(Index)));
if (if_dump_dirty_) {
OF_CUDA_CHECK(GPU(Malloc)(&table_dirty_flags_, table_capacity_ * sizeof(bool)));
}
Clear();
}
~OrdinalEncoder() {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(cudaFree(table_size_));
OF_CUDA_CHECK(cudaFreeHost(table_size_host_));
OF_CUDA_CHECK(cudaFree(table_keys_));
OF_CUDA_CHECK(cudaFree(table_indices_));
OF_CUDA_CHECK(GPU(Free)(table_size_));
OF_CUDA_CHECK(GPU(FreeHost)(table_size_host_));
OF_CUDA_CHECK(GPU(Free)(table_keys_));
OF_CUDA_CHECK(GPU(Free)(table_indices_));
if (if_dump_dirty_) { OF_CUDA_CHECK(GPU(Free)(table_dirty_flags_)); }
}
template<bool insert>
template<bool insert, bool dump_dirty_only>
void Encode(ep::Stream* stream, uint32_t num_keys, const Key* keys, Index* context) {
if (insert) {
RUN_CUDA_KERNEL((OrdinalEncodeKernel<Key, Index>), stream, num_keys, table_capacity_,
table_keys_, table_indices_, table_size_, num_keys, keys, context);
OF_CUDA_CHECK(cudaMemcpyAsync(table_size_host_, table_size_, sizeof(Index), cudaMemcpyDefault,
stream->As<ep::CudaStream>()->cuda_stream()));
CHECK_JUST(stream->Sync());
CHECK_LT(*table_size_host_, capacity_)
<< "The number of key is larger than cache size, please enlarge cache_memory_budget. ";
RUN_CUDA_KERNEL((OrdinalEncodeKernel<Key, Index, dump_dirty_only>), stream, num_keys,
table_capacity_, table_keys_, table_indices_, table_dirty_flags_, table_size_,
num_keys, keys, context);
} else {
RUN_CUDA_KERNEL((OrdinalEncodeLookupKernel<Key, Index>), stream, num_keys, table_capacity_,
table_keys_, table_indices_, num_keys, keys, context);
......@@ -368,17 +417,35 @@ class OrdinalEncoder {
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, Key* keys, Index* context) {
OF_CUDA_CHECK(cudaMemsetAsync(n_dumped, 0, sizeof(uint32_t),
OF_CUDA_CHECK(GPU(MemsetAsync)(n_dumped, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index, false>), stream,
end_key_index - start_key_index, table_keys_, table_indices_,
table_dirty_flags_, start_key_index, end_key_index, n_dumped, keys, context);
}
void DumpDirtyOnly(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, Key* keys, Index* context) {
OF_CUDA_CHECK(GPU(MemsetAsync)(n_dumped, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index>), stream, end_key_index - start_key_index,
table_keys_, table_indices_, start_key_index, end_key_index, n_dumped, keys,
context);
RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index, true>), stream,
end_key_index - start_key_index, table_keys_, table_indices_,
table_dirty_flags_, start_key_index, end_key_index, n_dumped, keys, context);
}
void ClearDirtyFlags() {
if (if_dump_dirty_) {
OF_CUDA_CHECK(GPU(Memset)(table_dirty_flags_, 0, table_capacity_ * sizeof(bool)));
}
}
void Clear() {
OF_CUDA_CHECK(cudaMemset(table_size_, 0, sizeof(Index)));
OF_CUDA_CHECK(cudaMemset(table_keys_, 0, table_capacity_ * sizeof(Key)));
OF_CUDA_CHECK(cudaMemset(table_indices_, 0, table_capacity_ * sizeof(Index)));
OF_CUDA_CHECK(GPU(Memset)(table_size_, 0, sizeof(Index)));
OF_CUDA_CHECK(GPU(Memset)(table_keys_, 0, table_capacity_ * sizeof(Key)));
OF_CUDA_CHECK(GPU(Memset)(table_indices_, 0, table_capacity_ * sizeof(Index)));
if (if_dump_dirty_) {
OF_CUDA_CHECK(GPU(Memset)(table_dirty_flags_, 0, table_capacity_ * sizeof(bool)));
}
}
uint64_t TableCapacity() const { return table_capacity_; }
......@@ -391,8 +458,10 @@ class OrdinalEncoder {
int device_index_{};
Key* table_keys_;
Index* table_indices_;
bool* table_dirty_flags_;
uint64_t capacity_;
uint64_t table_capacity_;
bool if_dump_dirty_;
Index* table_size_{};
Index* table_size_host_{};
};
......@@ -402,17 +471,22 @@ class CacheImpl : public Cache {
public:
OF_DISALLOW_COPY_AND_MOVE(CacheImpl);
explicit CacheImpl(const CacheOptions& options)
: encoder_(options.capacity, options.load_factor),
: if_dump_dirty_(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DUMP_DIRTY_ONLY", false)),
encoder_(options.capacity, options.load_factor, if_dump_dirty_),
device_index_(-1),
options_(options),
max_query_length_(0) {
OF_CUDA_CHECK(cudaGetDevice(&device_index_));
OF_CUDA_CHECK(GPU(GetDevice)(&device_index_));
const uint64_t values_size = options.capacity * options.value_size;
if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(cudaMalloc(&values_, values_size));
OF_CUDA_CHECK(GPU(Malloc)(&values_, values_size));
} else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
OF_CUDA_CHECK(cudaMallocHost(&values_, values_size));
#ifdef WITH_ROCM
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&values_), values_size));
#else
OF_CUDA_CHECK(cudaMallocHost(&values_, values_size));
#endif
} else {
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&values_),
values_size));
......@@ -425,13 +499,13 @@ class CacheImpl : public Cache {
~CacheImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (options_.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(cudaFree(values_));
OF_CUDA_CHECK(GPU(Free)(values_));
} else if (options_.value_memory_kind == CacheOptions::MemoryKind::kHost) {
OF_CUDA_CHECK(cudaFreeHost(values_));
OF_CUDA_CHECK(GPU(FreeHost)(values_));
} else {
UNIMPLEMENTED();
}
if (max_query_length_ > 0) { OF_CUDA_CHECK(cudaFree(encoding_buffer_)); }
if (max_query_length_ > 0) { OF_CUDA_CHECK(GPU(Free)(encoding_buffer_)); }
}
uint64_t Capacity() const override { return options_.capacity; }
......@@ -447,8 +521,8 @@ class CacheImpl : public Cache {
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ > 0) { OF_CUDA_CHECK(cudaFree(encoding_buffer_)); }
OF_CUDA_CHECK(cudaMalloc(&encoding_buffer_, query_length * sizeof(uint64_t)));
if (max_query_length_ > 0) { OF_CUDA_CHECK(GPU(Free)(encoding_buffer_)); }
OF_CUDA_CHECK(GPU(Malloc)(&encoding_buffer_, query_length * sizeof(uint64_t)));
max_query_length_ = query_length;
}
......@@ -465,15 +539,19 @@ class CacheImpl : public Cache {
void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override;
void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
const void* update, const float* lr, float scale, uint32_t* n_evicted,
void* evicted_keys, void* evicted_values) override;
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, void* keys, void* values) override;
void ClearDirtyFlags() override;
void Clear() override;
private:
bool if_dump_dirty_;
OrdinalEncoder<Key, Index> encoder_;
int device_index_;
uint32_t num_elem_per_value_{};
......@@ -488,10 +566,16 @@ void CacheImpl<Key, Elem, Index, pack_size>::Test(ep::Stream* stream, uint32_t n
const void* keys, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) {
OF_CUDA_CHECK(
cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
GPU(MemsetAsync)(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<false>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
if (if_dump_dirty_) {
encoder_.template Encode<false, true>(stream, n_keys, static_cast<const Key*>(keys),
encoding_buffer_);
} else {
encoder_.template Encode<false, false>(stream, n_keys, static_cast<const Key*>(keys),
encoding_buffer_);
}
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((LookupKernel<Key, Elem, Index, false>), stream, values_elem_cnt,
num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
......@@ -505,7 +589,7 @@ void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_
uint32_t* n_missing, void* missing_keys,
uint32_t* missing_indices) {
OF_CUDA_CHECK(
cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
GPU(MemsetAsync)(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
constexpr uint32_t block_size = 128;
......@@ -539,11 +623,15 @@ void CacheImpl<Key, Elem, Index, pack_size>::Put(ep::Stream* stream, uint32_t n_
const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys,
void* evicted_values) {
OF_CUDA_CHECK(
cudaMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
if (if_dump_dirty_) {
encoder_.template Encode<true, true>(stream, n_keys, static_cast<const Key*>(keys),
encoding_buffer_);
} else {
encoder_.template Encode<true, false>(stream, n_keys, static_cast<const Key*>(keys),
encoding_buffer_);
}
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((UpdateKernel<Elem, Index, pack_size>), stream, values_elem_cnt / pack_size,
num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_,
......@@ -555,28 +643,43 @@ void CacheImpl<Key, Elem, Index, pack_size>::FusedHalfUpdatePut(
ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update,
const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) {
if (!std::is_same<Elem, float>::value) { UNIMPLEMENTED(); }
OF_CUDA_CHECK(
cudaMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
if (n_keys == 0) { return; }
CHECK_LE(n_keys, max_query_length_);
encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
if (if_dump_dirty_) {
encoder_.template Encode<true, true>(stream, n_keys, static_cast<const Key*>(keys),
encoding_buffer_);
} else {
encoder_.template Encode<true, false>(stream, n_keys, static_cast<const Key*>(keys),
encoding_buffer_);
}
const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
RUN_CUDA_KERNEL((FusedHalfUpdateKernel<Elem, Index, pack_size>), stream,
values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt,
encoding_buffer_, static_cast<const Elem*>(values),
static_cast<const half*>(update), lr, scale);
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Dump(ep::Stream* stream, uint64_t start_key_index,
uint64_t end_key_index, uint32_t* n_dumped,
void* keys, void* values) {
encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
encoding_buffer_);
if (if_dump_dirty_) {
encoder_.DumpDirtyOnly(stream, start_key_index, end_key_index, n_dumped,
static_cast<Key*>(keys), encoding_buffer_);
} else {
encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
encoding_buffer_);
}
RUN_CUDA_KERNEL((DumpValueKernel<Key, Elem, Index>), stream,
num_elem_per_value_ * (end_key_index - start_key_index), num_elem_per_value_,
n_dumped, encoding_buffer_, values_, static_cast<Elem*>(values));
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::ClearDirtyFlags() {
encoder_.ClearDirtyFlags();
}
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Clear() {
encoder_.Clear();
......
......@@ -23,16 +23,11 @@ namespace oneflow {
namespace embedding {
#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_ROCM)
std::unique_ptr<Cache> NewFullCache(const CacheOptions& options);
#endif // WITH_CUDA
#ifdef WITH_ROCM
std::unique_ptr<Cache> NewFullCache(const CacheOptions& options);
#endif // WITH_ROCM
} // namespace embedding
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/full_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include "oneflow/core/hip/atomic.hip.h"
namespace oneflow {
namespace embedding {
using Key32 = unsigned int;
using Key64 = unsigned long long int;
using Key128 = ulonglong2;
namespace {
// Attempts to claim or read a single hash-table slot for `key`.
//
// Slot encoding:
//   entry_key   - (key | 0x1): forcing bit 0 to 1 makes every stored key nonzero, so 0
//                 can serve as the "empty" sentinel for the CAS below. The discarded
//                 low bit of the real key is preserved in entry_index instead.
//   entry_index - ((assigned_ordinal) << 1) | (key & 0x1), where assigned_ordinal is
//                 (table index + 1); 0 means the slot has been claimed by an inserting
//                 thread but its ordinal has not been published yet.
//
// Returns true with *out = ordinal (index + 1) when this slot holds `key` (whether we
// inserted it now or another thread did earlier). Returns false when the slot belongs
// to a different key, in which case the caller continues probing.
template<typename Key, typename Index>
__device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Index* table_size,
                               Key key, Index* out) {
  Key key_hi = (key | 0x1);
  Key key_lo = (key & 0x1);
  Index index_plus_one = 0;
  // One CAS claims the slot iff it is empty; otherwise it tells us who owns it.
  Key old_entry_key = cuda::atomic::CAS(entry_key, static_cast<Key>(0), key_hi);
  while (index_plus_one == 0) {
    if (old_entry_key == static_cast<Key>(0)) {
      // We won the empty slot: allocate a fresh ordinal and publish it.
      Index index = cuda::atomic::Add(table_size, static_cast<Index>(1));
      index_plus_one = index + 1;
      *entry_index = ((index_plus_one << 1U) | key_lo);
      *out = index_plus_one;
      return true;
    } else if (old_entry_key == key_hi) {
      // Slot holds a key with the same upper bits; disambiguate via the stored low bit.
      const Index entry_index_val = *entry_index;
      if (entry_index_val == 0) {
        // Owner claimed the slot but has not published its ordinal yet: spin.
        // entry_index is volatile, so the re-read on the next iteration observes it.
      } else if ((entry_index_val & 0x1) == key_lo) {
        *out = (entry_index_val >> 1U);
        return true;
      } else {
        // Same (key | 1) tag but different low bit: a distinct key occupies this slot.
        return false;
      }
    } else {
      // Slot is owned by an unrelated key; caller probes the next slot.
      return false;
    }
  }
  return false;
}
// Linear-probes the open-addressing table starting at hash % capacity, delegating each
// slot to TryGetOrInsert. Returns true once a slot yields or assigns an ordinal for
// `key`; returns false only if every slot in the table was probed without success
// (i.e. the table is full of other keys).
template<typename Key, typename Index>
__device__ bool GetOrInsertOne(const size_t capacity, Key* table_keys, Index* table_indices,
                               Index* table_size, Key key, size_t hash, Index* out) {
  size_t slot = hash % capacity;
  for (size_t probed = 0; probed < capacity; ++probed) {
    if (TryGetOrInsert<Key, Index>(table_keys + slot, table_indices + slot, table_size, key,
                                   out)) {
      return true;
    }
    slot = (slot + 1) % capacity;  // wrap-around linear probe
  }
  return false;
}
// Read-only lookup of `key` in the open-addressing table via linear probing.
// On a hit, writes the stored ordinal (index + 1) to *out and returns true.
// On a miss (empty slot terminates the probe chain, or the full table was scanned),
// writes 0 to *out and returns false.
template<typename Key, typename Index>
__device__ bool GetOne(const size_t capacity, Key* table_keys, Index* table_indices, Key key,
                       size_t hash, Index* out) {
  const size_t start_idx = hash % capacity;
  // Loop-invariant probe tag: slots store (key | 1); the dropped low bit of the key is
  // kept in bit 0 of the index word. Hoisted out of the probe loop.
  const Key key_hi = (key | 0x1);
  const Key key_lo = (key & 0x1);
  for (size_t count = 0; count < capacity; ++count) {
    const size_t idx = (start_idx + count) % capacity;
    const Key entry_key = table_keys[idx];
    // BUGFIX: read the slot's index with the Index type. The original declared this as
    // `Key`, which truncates stored ordinals whenever Index is wider than Key.
    const Index entry_index = table_indices[idx];
    // An empty slot means the key was never inserted along this probe chain.
    if (entry_key == 0) { break; }
    if (entry_key == key_hi) {
      if ((entry_index & 0x1) == key_lo) {
        *out = (entry_index >> 1U);
        return true;
      }
    }
  }
  *out = 0;
  return false;
}
// One thread per key: hashes the key and get-or-inserts it into the table, writing its
// ordinal (table index + 1) into context[i].
template<typename Key, typename Index>
__global__ void OrdinalEncodeKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
                                    Index* table_size, uint32_t num_keys, const Key* keys,
                                    Index* context) {
  CUDA_1D_KERNEL_LOOP(i, num_keys) {
    const Key cur_key = keys[i];
    const uint64_t cur_hash = FullCacheHash()(cur_key);
    const bool inserted = GetOrInsertOne<Key, Index>(capacity, table_keys, table_indices,
                                                     table_size, cur_key, cur_hash, context + i);
    // Get-or-insert can only fail when every table slot is occupied by other keys.
    assert(inserted);
  }
}
// Read-only variant: one thread per key, looks up each key's ordinal and writes it to
// context[i]. GetOne writes 0 to context[i] when the key is absent.
template<typename Key, typename Index>
__global__ void OrdinalEncodeLookupKernel(uint64_t capacity, Key* table_keys, Index* table_indices,
                                          uint32_t num_keys, const Key* keys, Index* context) {
  CUDA_1D_KERNEL_LOOP(i, num_keys) {
    const Key cur_key = keys[i];
    const uint64_t cur_hash = FullCacheHash()(cur_key);
    GetOne<Key, Index>(capacity, table_keys, table_indices, cur_key, cur_hash, context + i);
  }
}
// Compacts all occupied table slots in [start_key_index, end_key_index) into keys /
// context. n_dumped must be zeroed before launch; output order is nondeterministic
// because output positions are claimed with an atomic counter.
template<typename Key, typename Index>
__global__ void OrdinalEncodeDumpKernel(const Key* table_keys, const Index* table_indices,
                                        uint64_t start_key_index, uint64_t end_key_index,
                                        uint32_t* n_dumped, Key* keys, Index* context) {
  CUDA_1D_KERNEL_LOOP(i, (end_key_index - start_key_index)) {
    Key entry_key = table_keys[i + start_key_index];
    Index entry_index = table_indices[i + start_key_index];
    // entry_index == 0 marks an empty (or claimed-but-unpublished) slot.
    if (entry_index != 0) {
      uint32_t index = cuda::atomic::Add(n_dumped, static_cast<uint32_t>(1));
      // Slots store (key | 1); XOR clears the forced low bit, and the key's original
      // low bit is recovered from bit 0 of entry_index.
      keys[index] = ((entry_key ^ 0x1) | (entry_index & 0x1));
      // The stored ordinal is (index + 1) << 1 | low_bit; shift to recover it.
      context[index] = (entry_index >> 1U);
    }
  }
}
// Gathers cache rows for looked-up keys. One thread per output element; context[k] is
// the ordinal (row + 1) of key k, with 0 meaning "missing". Missing keys are compacted
// into missing_keys / missing_indices (one record per key, written by the thread that
// owns column 0). When return_value is false, only the miss list is produced.
// n_missing must be zeroed before launch.
template<typename Key, typename Elem, typename Index, bool return_value>
__global__ void LookupKernel(uint32_t value_length, const Elem* cache_values,
                             uint32_t values_elem_cnt, const Key* keys, const Index* context,
                             Elem* values, uint32_t* n_missing, Key* missing_keys,
                             uint32_t* missing_indices) {
  CUDA_1D_KERNEL_LOOP(i, values_elem_cnt) {
    const uint64_t out_row = i / value_length;
    const uint64_t out_col = i - out_row * value_length;
    const uint64_t ordinal = context[out_row];
    if (ordinal != 0) {
      // Hit: copy one element of the cached row (ordinal is row index + 1).
      if (return_value) { values[i] = cache_values[(ordinal - 1) * value_length + out_col]; }
    } else if (out_col == 0) {
      // Miss: exactly one thread per key appends to the global miss list.
      const uint32_t miss_slot = cuda::atomic::Add(n_missing, static_cast<uint32_t>(1));
      missing_keys[miss_slot] = keys[out_row];
      missing_indices[miss_slot] = out_row;
    }
  }
}
// Fused encode + lookup: each 32-lane group resolves a batch of up to 32 keys against
// the hash table, compacts misses into missing_keys / missing_indices, and
// cooperatively copies hit rows from cache_values into values.
// n_missing must be zeroed before launch.
//
// NOTE(review): this HIP port replaced __syncwarp() with __syncthreads(). The outer
// loop's trip count depends on global_warp_id, so near the tail of n_keys different
// warps of the same block can reach __syncthreads() a different number of times — a
// divergent block-wide barrier is undefined behavior. TODO confirm launch configs
// guarantee uniform trip counts per block, or restructure the loop.
// NOTE(review): batch_keys is read in the copy loop below but never written in this
// kernel; the loaded value is unused, so this is dead (uninitialized shared) code.
template<typename Key, typename Elem, typename Index, uint32_t block_size>
__global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_values,
                                   uint32_t values_elem_cnt, const Key* keys, const Index* context,
                                   Elem* values, uint32_t* n_missing, Key* missing_keys,
                                   uint32_t* missing_indices, const size_t capacity,
                                   Key* table_keys, Index* table_indices) {
  // warp_size is hard-coded to 32; here "warp" means a 32-lane sub-group coordinated
  // only through block barriers, regardless of the hardware wavefront width.
  constexpr uint32_t warp_size = 32;
  constexpr uint32_t n_warp_per_block = block_size / warp_size;
  const uint32_t warp_id = threadIdx.x / warp_size;
  const uint32_t lane_id = threadIdx.x % warp_size;
  const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
  const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
  const uint32_t n_keys = values_elem_cnt / value_length;
  __shared__ Key batch_keys[n_warp_per_block][warp_size];
  __shared__ Index batch_row_ids[n_warp_per_block][warp_size];
  __shared__ Key batch_missing_keys[n_warp_per_block][warp_size];
  __shared__ uint32_t batch_missing_indices[n_warp_per_block][warp_size];
  __shared__ uint32_t batch_n_missing[n_warp_per_block];
  // Each 32-lane group strides over the key range, one warp_size-sized batch at a time.
  for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
       batch_start += global_n_warp * warp_size) {
    const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
    if (lane_id == 0) { batch_n_missing[warp_id] = 0; }
    __syncthreads();
    const uint32_t key_offset = batch_start + lane_id;
    if (key_offset < n_keys) {
      const Key key = keys[batch_start + lane_id];
      const uint64_t hash = FullCacheHash()(key);
      Index row;
      // row == 0 means the key is absent from the table.
      GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
      batch_row_ids[warp_id][lane_id] = row;
      if (row == 0) {
        // First compact this group's misses into shared memory.
        const uint32_t batch_missing_idx = atomicAdd(batch_n_missing + warp_id, 1);
        batch_missing_keys[warp_id][batch_missing_idx] = key;
        batch_missing_indices[warp_id][batch_missing_idx] = key_offset;
      }
    }
    __syncthreads();
    const uint32_t batch_n_missing_t = batch_n_missing[warp_id];
    if (lane_id == 0) {
      // Reserve a contiguous range in the global miss list; the shared counter slot is
      // reused to broadcast the base offset to the whole group.
      const uint32_t old_n_missing =
          cuda::atomic::Add(n_missing, static_cast<uint32_t>(batch_n_missing_t));
      batch_n_missing[warp_id] = old_n_missing;
    }
    __syncthreads();
    if (lane_id < batch_n_missing_t) {
      missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id];
      missing_indices[batch_n_missing[warp_id] + lane_id] = batch_missing_indices[warp_id][lane_id];
    }
    // Cooperative row copy: all lanes stream the columns of each hit row together.
    for (int i = 0; i < batch_n_key; ++i) {
      const Key key = batch_keys[warp_id][i];
      const Index row = batch_row_ids[warp_id][i];
      if (row == 0) { continue; }
      for (int col = lane_id; col < value_length; col += warp_size) {
        values[(batch_start + i) * value_length + col] =
            cache_values[(row - 1) * value_length + col];
      }
    }
    __syncthreads();
  }
}
// Fixed-size element pack: the alignas makes a Pack load/store a single aligned wide
// memory access of pack_size elements (used to vectorize row copies below).
template<typename T, size_t pack_size>
struct alignas(sizeof(T) * pack_size) Pack {
  T elem[pack_size];
};
// Fused encode + lookup with a hit mask instead of a miss list: for each key, writes
// mask[key] = (found in table) and, on a hit, cooperatively copies the cached row into
// values using pack_size-wide vector accesses. value_length must be divisible by
// pack_size, and values / cache_values must be aligned for Pack<Elem, pack_size>.
//
// NOTE(review): `context` is never used in this kernel — presumably kept for signature
// parity with EncodeLookupKernel; verify against the host-side launcher.
// NOTE(review): this HIP port replaced __syncwarp() with __syncthreads(); the outer
// loop's trip count depends on global_warp_id, so warps of one block can disagree on
// the number of barrier arrivals near the tail of n_keys — a divergent block barrier
// is undefined behavior. TODO confirm or restructure.
// NOTE(review): batch_keys is read below but never written here; the loaded value is
// unused, so this is dead (uninitialized shared) code.
template<typename Key, typename Elem, typename Index, uint32_t block_size, uint32_t pack_size>
__global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __restrict__ cache_values,
                                       uint32_t values_elem_cnt, const Key* __restrict__ keys,
                                       const Index* __restrict__ context, Elem* __restrict__ values,
                                       uint8_t* __restrict__ mask, const size_t capacity,
                                       Key* __restrict__ table_keys,
                                       Index* __restrict__ table_indices) {
  const uint32_t packed_cols = value_length / pack_size;
  auto* packed_values = reinterpret_cast<Pack<Elem, pack_size>*>(values);
  const auto* packed_cache_values = reinterpret_cast<const Pack<Elem, pack_size>*>(cache_values);
  // 32-lane sub-groups coordinated via block barriers (see NOTE above).
  constexpr uint32_t warp_size = 32;
  constexpr uint32_t n_warp_per_block = block_size / warp_size;
  const uint32_t warp_id = threadIdx.x / warp_size;
  const uint32_t lane_id = threadIdx.x % warp_size;
  const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
  const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
  const uint32_t n_keys = values_elem_cnt / value_length;
  __shared__ Key batch_keys[n_warp_per_block][warp_size];
  __shared__ Index batch_row_ids[n_warp_per_block][warp_size];
  for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
       batch_start += global_n_warp * warp_size) {
    const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
    const uint32_t key_offset = batch_start + lane_id;
    if (key_offset < n_keys) {
      const Key key = keys[batch_start + lane_id];
      const uint64_t hash = FullCacheHash()(key);
      Index row;
      // row == 0 means the key is absent from the table.
      GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
      batch_row_ids[warp_id][lane_id] = row;
      mask[key_offset] = row > 0;
    }
    __syncthreads();
    // Cooperative vectorized row copy for every hit in this batch.
    for (int i = 0; i < batch_n_key; ++i) {
      const Key key = batch_keys[warp_id][i];
      const Index row = batch_row_ids[warp_id][i];
      if (row == 0) { continue; }
#pragma unroll 4
      for (int col = lane_id; col < packed_cols; col += warp_size) {
        packed_values[(batch_start + i) * packed_cols + col] =
            packed_cache_values[(row - 1) * packed_cols + col];
      }
    }
    __syncthreads();
  }
}
// Writes new row values into the cache for every present key. context[k] is the
// ordinal (row + 1) of key k; rows with ordinal 0 (absent keys) are skipped. Accesses
// are pack_size-wide vectors, so value_length and values_elem_cnt must both be
// divisible by pack_size and the pointers aligned for Pack<Elem, pack_size>.
template<typename Elem, typename Index, size_t pack_size>
__global__ void UpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
                             const Index* context, const Elem* values) {
  const int n_packs = values_elem_cnt / pack_size;
  const uint32_t packs_per_row = value_length / pack_size;
  auto* dst_rows = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
  const auto* src_rows = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
  CUDA_1D_KERNEL_LOOP(i, n_packs) {
    const uint64_t key_id = i / packs_per_row;
    const uint64_t ordinal = context[key_id];
    if (ordinal != 0) {
      const uint64_t pack_col = i - key_id * packs_per_row;
      dst_rows[(ordinal - 1) * packs_per_row + pack_col] = src_rows[i];
    }
  }
}
// Fused SGD-style update for float caches: for every present key (context[k] != 0)
// computes values + update * (-*lr * scale) and stores the result into the cache row.
// Enabled via enable_if only when Elem is float (see the non-float trap overload
// below). value_length and values_elem_cnt must be divisible by pack_size; pointers
// must be aligned for the Pack vector types.
//
//   context - per-key ordinal (row + 1); 0 means "not cached", row is skipped.
//   update  - half-precision gradient, same layout as values.
//   lr      - device pointer to the learning rate, read by every thread.
//   scale   - additional gradient scale factor.
template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* __restrict__ cache_values,
                      uint32_t values_elem_cnt, const Index* __restrict__ context,
                      const Elem* __restrict__ values, const half* __restrict__ update,
                      const float* __restrict__ lr, float scale) {
  const int packed_values_elem_cnt = values_elem_cnt / pack_size;
  const uint32_t packed_elem_cnt = value_length / pack_size;
  auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
  auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
  auto* packed_update = reinterpret_cast<const Pack<half, pack_size>*>(update);
  // Combined step size: negative so the update is subtracted (gradient descent).
  const float alpha = -*lr * scale;
  CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {
    const uint64_t key_id = i / packed_elem_cnt;
    const uint64_t ctx = context[key_id];
    if (ctx == 0) { continue; }  // key not cached: nothing to update
    const uint64_t row_id = ctx - 1;
    const uint64_t col_id = i - key_id * packed_elem_cnt;
    Pack<Elem, pack_size> m = packed_values[i];
    Pack<half, pack_size> u = packed_update[i];
    // Accumulate the half gradient in Elem (float) precision.
    for (size_t j = 0; j < pack_size; ++j) { m.elem[j] += static_cast<Elem>(u.elem[j]) * alpha; }
    packed_cache_values[row_id * packed_elem_cnt + col_id] = m;
  }
}
// Non-float specialization of FusedHalfUpdateKernel. It is never launched:
// the host-side FusedHalfUpdatePut aborts with UNIMPLEMENTED() before any
// launch when Elem != float. The body traps defensively should it ever run
// ("s_trap 0" is an AMD GCN trap instruction — this is the ROCm port of the
// CUDA __trap() stub).
template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
                      const Index* context, const Elem* values, const half* update, const float* lr,
                      float scale) {
  asm volatile("s_trap 0;");
}
// Gathers the cache rows of the *n_dumped dumped keys into dense `values`.
// context[k] holds row + 1 for the k-th dumped key; entries emitted by the
// dump pass are presumed valid, so no zero check is performed here.
template<typename Key, typename Elem, typename Index>
__global__ void DumpValueKernel(uint32_t value_length, const uint32_t* n_dumped,
                                const Index* context, const Elem* cache_values, Elem* values) {
  CUDA_1D_KERNEL_LOOP(i, *n_dumped * value_length) {
    const uint64_t key_idx = i / value_length;
    const uint64_t col = i - key_idx * value_length;
    const uint64_t row = static_cast<uint64_t>(context[key_idx]) - 1;
    values[i] = cache_values[row * value_length + col];
  }
}
// Assigns each distinct key a stable non-zero ordinal via a GPU open-addressing
// hash table of table_capacity_ = capacity / load_factor slots. An ordinal of 0
// in the output `context` means "key not present" (lookup-only mode).
template<typename Key, typename Index>
class OrdinalEncoder {
 public:
  OF_DISALLOW_COPY_AND_MOVE(OrdinalEncoder);
  // `load_factor` < 1 over-provisions the table so probe chains stay short.
  explicit OrdinalEncoder(uint64_t capacity, float load_factor)
      : capacity_(capacity), table_capacity_(capacity / load_factor) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    // Device-side counter of assigned ordinals plus a pinned host mirror used
    // to check for overflow after each inserting Encode pass.
    OF_CUDA_CHECK(hipMalloc(&table_size_, sizeof(Index)));
    OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&table_size_host_), sizeof(Index)));
    OF_CUDA_CHECK(hipMalloc(&table_keys_, table_capacity_ * sizeof(Key)));
    OF_CUDA_CHECK(hipMalloc(&table_indices_, table_capacity_ * sizeof(Index)));
    Clear();
  }
  ~OrdinalEncoder() {
    CudaCurrentDeviceGuard guard(device_index_);
    OF_CUDA_CHECK(hipFree(table_size_));
    OF_CUDA_CHECK(hipHostFree(table_size_host_));
    OF_CUDA_CHECK(hipFree(table_keys_));
    OF_CUDA_CHECK(hipFree(table_indices_));
  }
  // Writes each key's ordinal into `context`. With insert=true, unseen keys
  // get fresh ordinals and the call synchronizes the stream to verify that
  // capacity_ was not exceeded; with insert=false, unseen keys yield 0 and no
  // synchronization happens.
  template<bool insert>
  void Encode(ep::Stream* stream, uint32_t num_keys, const Key* keys, Index* context) {
    if (insert) {
      RUN_CUDA_KERNEL((OrdinalEncodeKernel<Key, Index>), stream, num_keys, table_capacity_,
                      table_keys_, table_indices_, table_size_, num_keys, keys, context);
      OF_CUDA_CHECK(hipMemcpyAsync(table_size_host_, table_size_, sizeof(Index), hipMemcpyDefault,
                                   stream->As<ep::CudaStream>()->cuda_stream()));
      CHECK_JUST(stream->Sync());
      CHECK_LT(*table_size_host_, capacity_)
          << "The number of key is larger than cache size, please enlarge cache_memory_budget. ";
    } else {
      RUN_CUDA_KERNEL((OrdinalEncodeLookupKernel<Key, Index>), stream, num_keys, table_capacity_,
                      table_keys_, table_indices_, num_keys, keys, context);
    }
  }
  // Emits the (key, ordinal) pairs held in table slots
  // [start_key_index, end_key_index); *n_dumped receives the pair count.
  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
            uint32_t* n_dumped, Key* keys, Index* context) {
    OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index>), stream, end_key_index - start_key_index,
                    table_keys_, table_indices_, start_key_index, end_key_index, n_dumped, keys,
                    context);
  }
  // Resets the table to empty. Synchronous: uses blocking hipMemset on the
  // default stream.
  void Clear() {
    OF_CUDA_CHECK(hipMemset(table_size_, 0, sizeof(Index)));
    OF_CUDA_CHECK(hipMemset(table_keys_, 0, table_capacity_ * sizeof(Key)));
    OF_CUDA_CHECK(hipMemset(table_indices_, 0, table_capacity_ * sizeof(Index)));
  }
  uint64_t TableCapacity() const { return table_capacity_; }
  Key* table_keys() const { return table_keys_; }
  Index* table_indices() const { return table_indices_; }

 private:
  int device_index_{};
  Key* table_keys_;
  Index* table_indices_;
  uint64_t capacity_;        // maximum number of ordinals that may be assigned
  uint64_t table_capacity_;  // hash-table slot count (capacity / load_factor)
  Index* table_size_{};      // device counter of assigned ordinals
  Index* table_size_host_{}; // pinned host mirror of table_size_
};
// Full-cache implementation: every inserted key receives a fixed slot from
// OrdinalEncoder, so Put never evicts. Value storage lives either in device
// memory or in pinned host memory (NUMA-aware unless disabled via env var).
// Elem is the element type of a value row; pack_size is the vectorization
// width used by the update kernels.
template<typename Key, typename Elem, typename Index, size_t pack_size>
class CacheImpl : public Cache {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CacheImpl);
  explicit CacheImpl(const CacheOptions& options)
      : encoder_(options.capacity, options.load_factor),
        device_index_(-1),
        options_(options),
        max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    const uint64_t values_size = options.capacity * options.value_size;
    if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
      OF_CUDA_CHECK(hipMalloc(&values_, values_size));
    } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
      // Pinned host memory; placed on the NUMA node nearest the device unless
      // explicitly disabled.
      if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
        OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&values_), values_size));
      } else {
        OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&values_),
                                              values_size));
      }
    } else {
      UNIMPLEMENTED();
    }
    num_elem_per_value_ = options_.value_size / sizeof(Elem);
  }
  ~CacheImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    if (options_.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
      OF_CUDA_CHECK(hipFree(values_));
    } else if (options_.value_memory_kind == CacheOptions::MemoryKind::kHost) {
      OF_CUDA_CHECK(hipHostFree(values_));
    } else {
      UNIMPLEMENTED();
    }
    // encoding_buffer_ exists only after ReserveQueryLength has been called.
    if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
  }
  uint64_t Capacity() const override { return options_.capacity; }
  // Dump iterates hash-table slots, so its key space is the table capacity,
  // not the value capacity.
  uint64_t DumpCapacity() const override { return encoder_.TableCapacity(); }
  uint32_t KeySize() const override { return options_.key_size; }
  uint32_t ValueSize() const override { return options_.value_size; }
  DataType ValueType() const override { return options_.value_type; }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  // Grows (never shrinks) the per-query ordinal scratch buffer. Each slot is
  // sized sizeof(uint64_t), wide enough for either Index width.
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
    OF_CUDA_CHECK(hipMalloc(&encoding_buffer_, query_length * sizeof(uint64_t)));
    max_query_length_ = query_length;
  }
  CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kFull; }
  void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
            void* missing_keys, uint32_t* missing_indices) override;
  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
           void* missing_keys, uint32_t* missing_indices) override;
  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,
           uint8_t* mask) override;
  void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
           uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override;
  void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
                          const void* update, const float* lr, float scale, uint32_t* n_evicted,
                          void* evicted_keys, void* evicted_values) override;
  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
            uint32_t* n_dumped, void* keys, void* values) override;
  void Clear() override;

 private:
  OrdinalEncoder<Key, Index> encoder_;
  int device_index_;
  uint32_t num_elem_per_value_{};  // options_.value_size / sizeof(Elem)
  Elem* values_;                   // capacity * value_size bytes of row storage
  Index* encoding_buffer_{};       // per-query ordinals, max_query_length_ slots
  CacheOptions options_;
  uint32_t max_query_length_;
};
// Reports which of `keys` are absent from the cache without copying values.
// *n_missing is zeroed up front; missing keys and their positions within
// `keys` are appended to missing_keys / missing_indices by the kernel.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Test(ep::Stream* stream, uint32_t n_keys,
                                                  const void* keys, uint32_t* n_missing,
                                                  void* missing_keys, uint32_t* missing_indices) {
  OF_CUDA_CHECK(
      hipMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  // Lookup-only encode: keys never inserted map to ordinal 0 (= missing).
  encoder_.template Encode<false>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  // Values output is nullptr: this instantiation only records misses.
  RUN_CUDA_KERNEL((LookupKernel<Key, Elem, Index, false>), stream, values_elem_cnt,
                  num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
                  encoding_buffer_, nullptr, n_missing, static_cast<Key*>(missing_keys),
                  missing_indices);
}
// Looks up `keys`, copying found rows into `values` and recording misses in
// missing_keys / missing_indices (positions within `keys`). Encoding and
// lookup are fused into a single kernel launch sized by the key count.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, void* values,
                                                 uint32_t* n_missing, void* missing_keys,
                                                 uint32_t* missing_indices) {
  OF_CUDA_CHECK(
      hipMemsetAsync(n_missing, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  constexpr uint32_t block_size = 128;
  uint32_t grid_size = (n_keys + block_size - 1) / block_size;  // ceil-div over keys
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  EncodeLookupKernel<Key, Elem, Index, block_size>
      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
          num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
          encoding_buffer_, static_cast<Elem*>(values), n_missing, static_cast<Key*>(missing_keys),
          missing_indices, encoder_.TableCapacity(), encoder_.table_keys(),
          encoder_.table_indices());
}
// Mask-reporting lookup variant: mask[i] is set nonzero iff key i was found
// (see the row > 0 write in the mask kernel); rows for missing keys are left
// untouched in `values`. No miss list is produced.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, void* values, uint8_t* mask) {
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  constexpr uint32_t block_size = 128;
  uint32_t grid_size = (n_keys + block_size - 1) / block_size;  // ceil-div over keys
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  EncodeLookupMaskKernel<Key, Elem, Index, block_size, pack_size>
      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
          num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
          encoding_buffer_, static_cast<Elem*>(values), mask, encoder_.TableCapacity(),
          encoder_.table_keys(), encoder_.table_indices());
}
// Inserts or overwrites `values` for `keys`. A full cache never evicts, so
// *n_evicted is simply zeroed and evicted_keys / evicted_values are never
// written. The inserting Encode aborts if the cache capacity is exceeded.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Put(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, const void* values,
                                                 uint32_t* n_evicted, void* evicted_keys,
                                                 void* evicted_values) {
  OF_CUDA_CHECK(
      hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  // Launch over packed elements (pack_size elements per thread iteration).
  RUN_CUDA_KERNEL((UpdateKernel<Elem, Index, pack_size>), stream, values_elem_cnt / pack_size,
                  num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_,
                  static_cast<const Elem*>(values));
}
// Fused Put + SGD step: stores values[i] - (*lr * scale) * update[i] into the
// cached rows in one kernel. Only supported for Elem == float; any other
// instantiation aborts via UNIMPLEMENTED before launching. As with Put, a
// full cache never evicts, so *n_evicted is zeroed and the evicted outputs
// stay unused.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::FusedHalfUpdatePut(
    ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update,
    const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) {
  if (!std::is_same<Elem, float>::value) { UNIMPLEMENTED(); }
  OF_CUDA_CHECK(
      hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys), encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  RUN_CUDA_KERNEL((FusedHalfUpdateKernel<Elem, Index, pack_size>), stream,
                  values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt,
                  encoding_buffer_, static_cast<const Elem*>(values),
                  static_cast<const half*>(update), lr, scale);
}
// Exports entries whose hash-table slot lies in [start_key_index,
// end_key_index): the encoder dump produces keys, their count (*n_dumped) and
// their ordinals; DumpValueKernel then gathers the matching value rows. The
// value kernel is launched with the worst-case element count and reads
// *n_dumped on the device for its actual bound.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Dump(ep::Stream* stream, uint64_t start_key_index,
                                                  uint64_t end_key_index, uint32_t* n_dumped,
                                                  void* keys, void* values) {
  encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
                encoding_buffer_);
  RUN_CUDA_KERNEL((DumpValueKernel<Key, Elem, Index>), stream,
                  num_elem_per_value_ * (end_key_index - start_key_index), num_elem_per_value_,
                  n_dumped, encoding_buffer_, values_, static_cast<Elem*>(values));
}
// Drops all cached keys by resetting the encoder table. Value storage is left
// as-is: with no ordinal assigned, stale rows can no longer be looked up.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Clear() {
  encoder_.Clear();
}
// Chooses the element type / pack width for CacheImpl from the value dtype
// and byte size, preferring the widest access the layout allows.
template<typename Key, typename Index>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
  if (options.value_type == DataType::kFloat) {
    const size_t n_elem = options.value_size / sizeof(float);
    // Vectorize only while each row still spans more than half a warp of
    // packed elements.
    constexpr size_t kHalfWarp = 16;
    if (n_elem % 4 == 0 && n_elem / 4 > kHalfWarp) {
      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 4>(options));
    }
    if (n_elem % 2 == 0 && n_elem / 2 > kHalfWarp) {
      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 2>(options));
    }
    return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 1>(options));
  }
  // Non-float values: pick the widest unsigned element that divides the size.
  if (options.value_size % sizeof(ulonglong2) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, ulonglong2, Index, 1>(options));
  }
  if (options.value_size % sizeof(uint64_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint64_t, Index, 1>(options));
  }
  if (options.value_size % sizeof(uint32_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint32_t, Index, 1>(options));
  }
  if (options.value_size % sizeof(uint16_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint16_t, Index, 1>(options));
  }
  return std::unique_ptr<Cache>(new CacheImpl<Key, uint8_t, Index, 1>(options));
}
// Dispatches on the configured key width (4- or 8-byte keys).
template<typename Index>
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
  switch (options.key_size) {
    case sizeof(Key32): return DispatchValueType<Key32, Index>(options);
    case sizeof(Key64): return DispatchValueType<Key64, Index>(options);
    default: UNIMPLEMENTED(); return nullptr;
  }
}
// Uses 32-bit hash-table indices when every slot index fits below 2^31,
// otherwise falls back to 64-bit indices.
std::unique_ptr<Cache> DispatchIndexType(const CacheOptions& options) {
  const int64_t table_capacity = static_cast<double>(options.capacity) / options.load_factor;
  const bool needs_wide_index = table_capacity >= (1ULL << 31ULL);
  return needs_wide_index ? DispatchKeyType<uint64_t>(options) : DispatchKeyType<uint32_t>(options);
}
} // namespace
// Factory for the full (non-evicting) cache; the concrete CacheImpl
// instantiation is selected from the options' index, key and value layout.
std::unique_ptr<Cache> NewFullCache(const CacheOptions& options) {
  return DispatchIndexType(options);
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#define ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#include <stdint.h>
#include "oneflow/core/common/data_type.h"
namespace oneflow {
namespace embedding {
namespace {
// From https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h
static const uint64_t PRIME64_1 =
0x9E3779B185EBCA87ULL; // 0b1001111000110111011110011011000110000101111010111100101010000111
static const uint64_t PRIME64_2 =
0xC2B2AE3D27D4EB4FULL; // 0b1100001010110010101011100011110100100111110101001110101101001111
static const uint64_t PRIME64_3 =
0x165667B19E3779F9ULL; // 0b0001011001010110011001111011000110011110001101110111100111111001
static const uint64_t PRIME64_4 =
0x85EBCA77C2B2AE63ULL; // 0b1000010111101011110010100111011111000010101100101010111001100011
static const uint64_t PRIME64_5 =
0x27D4EB2F165667C5ULL; // 0b0010011111010100111010110010111100010110010101100110011111000101
#define XXH_rotl64(x, r) (((x) << (r)) | ((x) >> (64 - (r))))
// One xxHash64 accumulation round: fold `input` into `acc` via
// multiply-add, rotate-left by 31, multiply (per the xxHash reference).
OF_DEVICE_FUNC uint64_t XXH64_round(uint64_t acc, uint64_t input) {
  const uint64_t mixed = acc + input * PRIME64_2;
  return XXH_rotl64(mixed, 31) * PRIME64_1;
}
// xxHash64 of a single 8-byte value: the generic algorithm specialized for an
// input length of exactly sizeof(uint64_t).
OF_DEVICE_FUNC uint64_t xxh64_uint64(uint64_t v, uint64_t seed) {
  uint64_t h = seed + PRIME64_5 + sizeof(uint64_t);
  h ^= XXH64_round(0, v);
  h = XXH_rotl64(h, 27) * PRIME64_1 + PRIME64_4;
  // Final avalanche: xor-shift / multiply cascade.
  h ^= h >> 33;
  h *= PRIME64_2;
  h ^= h >> 29;
  h *= PRIME64_3;
  h ^= h >> 32;
  return h;
}
// Distinct seeds keep the per-purpose hash functions below statistically
// independent even though they all share the same xxh64 core.
static const size_t kShardingHashSeed = 1;
static const size_t kLocalUniqueHashSeed = 2;
static const size_t kGlobalUniqueHashSeed = 3;
static const size_t kFullCacheHashSeed = 4;
static const size_t kLruCacheHashSeed = 5;
} // namespace
// Hash used to shard keys across partitions. Signed key types are first
// converted to their unsigned counterparts so that the same bit pattern
// hashes identically regardless of signedness.
struct ShardingHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kShardingHashSeed); }
  OF_DEVICE_FUNC size_t operator()(uint32_t v) { return xxh64_uint64(v, kShardingHashSeed); }
  OF_DEVICE_FUNC size_t operator()(int32_t v) {
    return xxh64_uint64(static_cast<uint32_t>(v), kShardingHashSeed);
  }
  OF_DEVICE_FUNC size_t operator()(int64_t v) {
    return xxh64_uint64(static_cast<uint64_t>(v), kShardingHashSeed);
  }
};
// Hash for per-device (local) key deduplication.
struct LocalUniqueHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLocalUniqueHashSeed); }
};
// Hash for cross-device (global) key deduplication.
struct GlobalUniqueHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kGlobalUniqueHashSeed); }
};
// Hash used by the full-cache ordinal encoder's open-addressing table.
struct FullCacheHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kFullCacheHashSeed); }
};
// Hash used by the LRU cache's set-associative indexing.
struct LruCacheHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLruCacheHashSeed); }
};
} // namespace embedding
} // namespace oneflow
#endif // ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
\ No newline at end of file
......@@ -263,6 +263,242 @@ TEST(MockKeyValueStore, Mock) {
#endif // WITH_CUDA
#ifdef WITH_ROCM
// Creates a unique temporary directory under $TMPDIR (default /tmp) and
// returns its path. Aborts via PCHECK (with errno) if creation fails.
std::string CreateTempDirectory() {
  const char* tmp_env = getenv("TMPDIR");
  const char* tmp_dir = tmp_env == nullptr ? "/tmp" : tmp_env;
  std::string tpl = std::string(tmp_dir) + "/test_kv_XXXXXX";
  // mkdtemp rewrites the XXXXXX suffix in place. Write through &tpl[0], which
  // is guaranteed modifiable since C++11; writing through a const_cast of
  // c_str() (as before) is formally undefined behavior.
  char* path = mkdtemp(&tpl[0]);
  PCHECK(path != nullptr);
  return std::string(path);
}
// True when the HIP runtime is usable and reports at least one device.
bool HasCudaDevice() {
  int device_count = 0;
  const bool runtime_ok = hipGetDeviceCount(&device_count) == hipSuccess;
  return runtime_ok && device_count > 0;
}
// Round-trips `test_embeddings` of `num_embeddings` keys through `store`,
// checking Get/Put plus snapshot save/load semantics:
//   1. snapshot "init" of the empty store,
//   2. batch-Put all keys (after first confirming each batch misses),
//   3. snapshot "final"; verify every key hits with the expected values,
//   4. restore "init"  -> every key misses again,
//   5. restore "final" -> every key hits again with the expected values.
// Key i+1 serves as both the key and every element of its value vector, so
// correctness can be checked without a separate reference table.
void TestKeyValueStore(KeyValueStore* store, size_t num_embeddings, size_t test_embeddings,
                       size_t embedding_vec_size) {
  auto device = Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, 0);
  ep::Stream* stream = device->CreateStream();
  // Snapshot the empty store before anything is inserted.
  store->SaveSnapshot("init");
  uint64_t* keys = nullptr;
  float* values = nullptr;
  float* values1 = nullptr;  // scratch output for the miss-only Get passes
  uint64_t* keys_host = nullptr;
  float* values_host = nullptr;
  uint64_t* context = nullptr;
  uint32_t* n_missing = nullptr;
  uint32_t* host_n_missing = nullptr;
  // NOTE(review): missing_keys is allocated/freed but never passed to Get —
  // this store API variant only reports missing indices.
  uint64_t* missing_keys = nullptr;
  uint32_t* missing_indices = nullptr;
  size_t keys_size = sizeof(uint64_t) * num_embeddings;
  size_t values_size = sizeof(float) * embedding_vec_size * num_embeddings;
  size_t context_size = sizeof(uint64_t) * num_embeddings;
  const size_t batch_size = 128;
  OF_CUDA_CHECK(hipMalloc(&keys, keys_size));
  OF_CUDA_CHECK(hipMalloc(&values, values_size));
  OF_CUDA_CHECK(hipMalloc(&values1, values_size));
  OF_CUDA_CHECK(hipMalloc(&context, context_size));
  OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&keys_host), keys_size));
  OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&values_host), values_size));
  OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&host_n_missing), sizeof(uint32_t)));
  OF_CUDA_CHECK(hipMalloc(&missing_keys, batch_size * sizeof(uint64_t)));
  OF_CUDA_CHECK(hipMalloc(&missing_indices, batch_size * sizeof(uint32_t)));
  OF_CUDA_CHECK(hipMalloc(&n_missing, sizeof(uint32_t)));
  // Keys start at 1; each value vector is filled with its own key.
  for (size_t i = 0; i < num_embeddings; ++i) {
    uint64_t key = i + 1;
    keys_host[i] = key;
    for (size_t j = 0; j < embedding_vec_size; j++) {
      values_host[i * embedding_vec_size + j] = key;
    }
  }
  OF_CUDA_CHECK(hipMemcpy(keys, keys_host, keys_size, hipMemcpyDefault));
  OF_CUDA_CHECK(hipMemcpy(values, values_host, values_size, hipMemcpyDefault));
  // Zero-key Put exercises the n == 0 fast path.
  store->Put(stream, 0, keys, values);
  OF_CUDA_CHECK(hipDeviceSynchronize());
  OF_CUDA_CHECK(hipGetLastError());
  // Phase 2: every batch should initially miss, then be inserted.
  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {
    const size_t num_keys = std::min(batch_size, test_embeddings - offset);
    store->Get(stream, num_keys, keys + offset, values1 + offset * embedding_vec_size, n_missing,
               missing_indices);
    OF_CUDA_CHECK(hipMemcpy(host_n_missing, n_missing, sizeof(uint32_t), hipMemcpyDefault));
    OF_CUDA_CHECK(hipDeviceSynchronize());
    ASSERT_EQ(*host_n_missing, num_keys);
    store->Put(stream, num_keys, keys + offset, values + offset * embedding_vec_size);
  }
  OF_CUDA_CHECK(hipDeviceSynchronize());
  // Phase 3: snapshot the populated store, then read everything back.
  store->SaveSnapshot("final");
  // NOTE(review): hipMemset on values_host (pinned host memory) relies on
  // HIP unified addressing — plain memset would be clearer; confirm intent.
  OF_CUDA_CHECK(hipMemset(values_host, 0, values_size));
  OF_CUDA_CHECK(hipMemset(values, 0, values_size));
  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {
    const size_t num_keys = std::min(batch_size, test_embeddings - offset);
    store->Get(stream, num_keys, keys + offset, values + offset * embedding_vec_size, n_missing,
               missing_indices);
    OF_CUDA_CHECK(hipMemcpy(host_n_missing, n_missing, sizeof(uint32_t), hipMemcpyDefault));
    OF_CUDA_CHECK(hipDeviceSynchronize());
    ASSERT_EQ(*host_n_missing, 0);
  }
  OF_CUDA_CHECK(hipMemcpy(values_host, values, values_size, hipMemcpyDefault));
  OF_CUDA_CHECK(hipDeviceSynchronize());
  for (size_t i = 0; i < test_embeddings; ++i) {
    uint64_t key = keys_host[i];
    for (size_t j = 0; j < embedding_vec_size; j++) {
      ASSERT_EQ(values_host[i * embedding_vec_size + j], key);
    }
  }
  // Phase 4: restoring the empty snapshot must make every key miss again.
  store->LoadSnapshot("init");
  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {
    const size_t num_keys = std::min(batch_size, test_embeddings - offset);
    store->Get(stream, num_keys, keys + offset, values1 + offset * embedding_vec_size, n_missing,
               missing_indices);
    OF_CUDA_CHECK(hipMemcpy(host_n_missing, n_missing, sizeof(uint32_t), hipMemcpyDefault));
    OF_CUDA_CHECK(hipDeviceSynchronize());
    ASSERT_EQ(*host_n_missing, num_keys);
  }
  // Phase 5: restoring the full snapshot must bring all values back intact.
  store->LoadSnapshot("final");
  OF_CUDA_CHECK(hipMemset(values_host, 0, values_size));
  OF_CUDA_CHECK(hipMemset(values, 0, values_size));
  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {
    const size_t num_keys = std::min(batch_size, test_embeddings - offset);
    store->Get(stream, num_keys, keys + offset, values + offset * embedding_vec_size, n_missing,
               missing_indices);
    OF_CUDA_CHECK(hipMemcpy(host_n_missing, n_missing, sizeof(uint32_t), hipMemcpyDefault));
    OF_CUDA_CHECK(hipDeviceSynchronize());
    ASSERT_EQ(*host_n_missing, 0);
  }
  OF_CUDA_CHECK(hipMemcpy(values_host, values, values_size, hipMemcpyDefault));
  OF_CUDA_CHECK(hipDeviceSynchronize());
  for (size_t i = 0; i < test_embeddings; ++i) {
    uint64_t key = keys_host[i];
    for (size_t j = 0; j < embedding_vec_size; j++) {
      ASSERT_EQ(values_host[i * embedding_vec_size + j], key);
    }
  }
  // Teardown: release every buffer and the stream.
  OF_CUDA_CHECK(hipDeviceSynchronize());
  OF_CUDA_CHECK(hipGetLastError());
  OF_CUDA_CHECK(hipFree(keys));
  OF_CUDA_CHECK(hipFree(values));
  OF_CUDA_CHECK(hipFree(values1));
  OF_CUDA_CHECK(hipHostFree(keys_host));
  OF_CUDA_CHECK(hipHostFree(values_host));
  OF_CUDA_CHECK(hipHostFree(host_n_missing));
  OF_CUDA_CHECK(hipFree(n_missing));
  OF_CUDA_CHECK(hipFree(missing_keys));
  OF_CUDA_CHECK(hipFree(missing_indices));
  CHECK_JUST(stream->Sync());
  device->DestroyStream(stream);
}
// End-to-end test of the persistent-table-backed store: 1024 keys with
// 128-float values in a fresh temp dir, removed again afterwards.
TEST(PersistentTableKeyValueStore, PersistentTableKeyValueStore) {
  if (!HasCudaDevice()) { return; }  // skip silently on machines without a GPU
  Singleton<ep::DeviceManagerRegistry>::New();
  PersistentTableKeyValueStoreOptions options{};
  uint32_t value_length = 128;
  std::string path = CreateTempDirectory();
  options.table_options.path = path;
  options.table_options.value_size = value_length * sizeof(float);
  options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);
  options.table_options.physical_block_size = 512;
  std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(options);
  // Query length must cover the harness's batch size of 128.
  store->ReserveQueryLength(128);
  TestKeyValueStore(store.get(), 1024, 1024, value_length);
  store.reset();
  PosixFile::RecursiveDelete(path);
  Singleton<ep::DeviceManagerRegistry>::Delete();
}
// TEST(CachedKeyValueStore, LRU) {
// if (!HasCudaDevice()) { return; }
// Singleton<ep::DeviceManagerRegistry>::New();
// PersistentTableKeyValueStoreOptions store_options{};
// std::string path = CreateTempDirectory();
// store_options.table_options.path = path;
// uint32_t value_length = 128;
// store_options.table_options.value_size = value_length * sizeof(float);
// store_options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);
// store_options.table_options.physical_block_size = 512;
// std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(store_options);
// CacheOptions cache_options{};
// cache_options.policy = CacheOptions::Policy::kLRU;
// cache_options.value_memory_kind = CacheOptions::MemoryKind::kDevice;
// cache_options.value_size = 512;
// cache_options.capacity = 512;
// cache_options.key_size = 8;
// std::unique_ptr<Cache> cache = NewCache(cache_options);
// std::unique_ptr<KeyValueStore> cached_store =
// NewCachedKeyValueStore(std::move(store), std::move(cache));
// cached_store->ReserveQueryLength(128);
// TestKeyValueStore(cached_store.get(), 1024, 1024, value_length);
// cached_store.reset();
// PosixFile::RecursiveDelete(path);
// Singleton<ep::DeviceManagerRegistry>::Delete();
// }
// Tests the persistent store fronted by a full (non-evicting) cache in pinned
// host memory. Cache capacity (2048) exceeds the 1024 tested keys, so the
// full cache never overflows during the run.
TEST(CachedKeyValueStore, Full) {
  if (!HasCudaDevice()) { return; }  // skip silently on machines without a GPU
  Singleton<ep::DeviceManagerRegistry>::New();
  PersistentTableKeyValueStoreOptions store_options{};
  std::string path = CreateTempDirectory();
  store_options.table_options.path = path;
  uint32_t value_length = 128;
  store_options.table_options.value_size = value_length * sizeof(float);
  store_options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);
  store_options.table_options.physical_block_size = 512;
  std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(store_options);
  CacheOptions cache_options{};
  cache_options.policy = CacheOptions::Policy::kFull;
  cache_options.value_memory_kind = CacheOptions::MemoryKind::kHost;
  cache_options.value_size = 512;  // bytes: 128 floats per value
  cache_options.capacity = 1024 * 2;
  cache_options.key_size = 8;
  std::unique_ptr<Cache> cache = NewCache(cache_options);
  std::unique_ptr<KeyValueStore> cached_store =
      NewCachedKeyValueStore(std::move(store), std::move(cache));
  cached_store->ReserveQueryLength(128);
  TestKeyValueStore(cached_store.get(), 1024, 1024, value_length);
  cached_store.reset();
  PosixFile::RecursiveDelete(path);
  Singleton<ep::DeviceManagerRegistry>::Delete();
}
// Runs the same harness against the in-memory mock store. The temp dir is
// created only so teardown mirrors the other tests; the mock takes no path.
TEST(MockKeyValueStore, Mock) {
  if (!HasCudaDevice()) { return; }  // skip silently on machines without a GPU
  Singleton<ep::DeviceManagerRegistry>::New();
  MockKeyValueStoreOptions store_options{};
  std::string path = CreateTempDirectory();
  uint32_t value_length = 128;
  store_options.value_size = value_length * sizeof(float);
  store_options.key_size = GetSizeOfDataType(DataType::kUInt64);
  std::unique_ptr<KeyValueStore> store = NewMockKeyValueStore(store_options);
  store->ReserveQueryLength(128);
  TestKeyValueStore(store.get(), 1024, 1024, value_length);
  store.reset();
  PosixFile::RecursiveDelete(path);
  Singleton<ep::DeviceManagerRegistry>::Delete();
}
#endif // WITH_ROCM
} // namespace
} // namespace embedding
......
......@@ -20,9 +20,14 @@ limitations under the License.
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.cuh"
#include <new>
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda.h>
#endif
#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700))
#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700)) \
&& !(defined(__clang__) && defined(__CUDA__))
#include <cuda/std/semaphore>
#endif
......@@ -32,7 +37,11 @@ namespace embedding {
namespace {
#ifdef WITH_ROCM
constexpr int kWarpSize = 64;
#else
constexpr int kWarpSize = 32;
#endif
constexpr int kNumWarpPerBlock = 4;
constexpr int kBlockSize = kNumWarpPerBlock * kWarpSize;
constexpr uint32_t kFullMask = 0xFFFFFFFFU;
......@@ -69,11 +78,19 @@ class WarpMutexAtomicImpl {
;
}
__threadfence();
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
__device__ void Unlock(const ThreadContext& thread_ctx) {
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
__threadfence();
if (thread_ctx.lane_id == 0) { atomicExch(&flag_, 0); }
}
......@@ -82,7 +99,8 @@ class WarpMutexAtomicImpl {
int32_t flag_;
};
#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700))
#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700)) \
&& !(defined(__clang__) && defined(__CUDA__))
class WarpMutexSemaphoreImpl {
public:
......@@ -92,11 +110,19 @@ class WarpMutexSemaphoreImpl {
__device__ void Lock(const ThreadContext& thread_ctx) {
if (thread_ctx.lane_id == 0) { semaphore_.acquire(); }
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
__device__ void Unlock(const ThreadContext& thread_ctx) {
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
if (thread_ctx.lane_id == 0) { semaphore_.release(); }
}
......@@ -118,19 +144,20 @@ struct LruCacheContext {
};
__global__ void InitCacheSetMutex(uint32_t n_set, void* mutex) {
#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) && defined(__CUDA__))
using WarpMutex = WarpMutexSemaphoreImpl;
#else
using WarpMutex = WarpMutexAtomicImpl;
#endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
#endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) &&
// defined(__CUDA__))
const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n_set) { new (reinterpret_cast<WarpMutex*>(mutex) + idx) WarpMutex; }
}
template<typename Key, typename Elem>
void ClearLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
OF_CUDA_CHECK(cudaMemset(ctx->keys, 0, ctx->n_set * kWarpSize * sizeof(Key)));
OF_CUDA_CHECK(cudaMemset(ctx->ages, 0, ctx->n_set * kWarpSize * sizeof(uint8_t)));
OF_CUDA_CHECK(GPU(Memset)(ctx->keys, 0, ctx->n_set * kWarpSize * sizeof(Key)));
OF_CUDA_CHECK(GPU(Memset)(ctx->ages, 0, ctx->n_set * kWarpSize * sizeof(uint8_t)));
InitCacheSetMutex<<<(ctx->n_set - 1 + 256) / 256, 256>>>(ctx->n_set, ctx->mutex);
}
......@@ -141,11 +168,13 @@ void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>
const size_t lines_size_per_set = kWarpSize * line_size * sizeof(Elem);
const size_t ages_size_per_set = kWarpSize * sizeof(uint8_t);
int device = 0;
OF_CUDA_CHECK(cudaGetDevice(&device));
OF_CUDA_CHECK(GPU(GetDevice)(&device));
int major = 0;
#ifdef WITH_CUDA
OF_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
#endif
size_t mutex_size_per_set = 0;
#if CUDA_VERSION >= 11000
#if CUDA_VERSION >= 11000 && !(defined(__clang__) && defined(__CUDA__))
if (major >= 7) {
#if !defined(__CUDA_ARCH__)
mutex_size_per_set = sizeof(WarpMutexSemaphoreImpl);
......@@ -157,19 +186,23 @@ void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>
}
#else
mutex_size_per_set = sizeof(WarpMutexAtomicImpl);
#endif // CUDA_VERSION >= 11000
#endif // CUDA_VERSION >= 11000 && !(defined(__clang__) && defined(__CUDA__))
const size_t n_set = (options.capacity - 1 + kWarpSize) / kWarpSize;
CHECK_GT(n_set, 0);
ctx->n_set = n_set;
ctx->line_size = line_size;
const size_t keys_size = n_set * keys_size_per_set;
OF_CUDA_CHECK(cudaMalloc(&(ctx->keys), keys_size));
OF_CUDA_CHECK(GPU(Malloc)(&(ctx->keys), keys_size));
const size_t lines_size = n_set * lines_size_per_set;
if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(cudaMalloc(&(ctx->lines), lines_size));
OF_CUDA_CHECK(GPU(Malloc)(&(ctx->lines), lines_size));
} else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
#ifdef WITH_ROCM
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&(ctx->lines)), lines_size));
#else
OF_CUDA_CHECK(cudaMallocHost(&(ctx->lines), lines_size));
#endif
} else {
OF_CUDA_CHECK(
NumaAwareCudaMallocHost(device, reinterpret_cast<void**>(&ctx->lines), lines_size));
......@@ -179,45 +212,50 @@ void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>
}
ctx->value_memory_kind = options.value_memory_kind;
const size_t ages_size = n_set * ages_size_per_set;
OF_CUDA_CHECK(cudaMalloc(&(ctx->ages), ages_size));
OF_CUDA_CHECK(GPU(Malloc)(&(ctx->ages), ages_size));
const size_t mutex_size = n_set * mutex_size_per_set;
OF_CUDA_CHECK(cudaMalloc(&(ctx->mutex), mutex_size));
OF_CUDA_CHECK(GPU(Malloc)(&(ctx->mutex), mutex_size));
ClearLruCacheContext(ctx);
}
template<typename Key, typename Elem>
void DestroyLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
OF_CUDA_CHECK(cudaFree(ctx->keys));
OF_CUDA_CHECK(GPU(Free)(ctx->keys));
if (ctx->value_memory_kind == CacheOptions::MemoryKind::kDevice) {
OF_CUDA_CHECK(cudaFree(ctx->lines));
OF_CUDA_CHECK(GPU(Free)(ctx->lines));
} else if (ctx->value_memory_kind == CacheOptions::MemoryKind::kHost) {
OF_CUDA_CHECK(cudaFreeHost(ctx->lines));
OF_CUDA_CHECK(GPU(FreeHost)(ctx->lines));
} else {
UNIMPLEMENTED();
}
OF_CUDA_CHECK(cudaFree(ctx->ages));
OF_CUDA_CHECK(cudaFree(ctx->mutex));
OF_CUDA_CHECK(GPU(Free)(ctx->ages));
OF_CUDA_CHECK(GPU(Free)(ctx->mutex));
}
template<typename Key, typename Elem>
struct SetContext {
#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) && defined(__CUDA__))
using WarpMutex = WarpMutexSemaphoreImpl;
#else
using WarpMutex = WarpMutexAtomicImpl;
#endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
#endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) &&
// defined(__CUDA__))
__device__ SetContext(const LruCacheContext<Key, Elem>& ctx, uint32_t set_id)
: keys(ctx.keys + set_id * kWarpSize),
mutex(reinterpret_cast<WarpMutex*>(ctx.mutex) + set_id),
ages(ctx.ages + set_id * kWarpSize),
lines(ctx.lines + set_id * kWarpSize * ctx.line_size) {}
lines(ctx.lines + static_cast<size_t>(set_id) * kWarpSize * ctx.line_size) {}
__device__ int Lookup(const ThreadContext& thread_ctx, Key key) {
const Key lane_key = keys[thread_ctx.lane_id];
const int lane_age = ages[thread_ctx.lane_id];
const bool lane_hit = (lane_key == key && lane_age != 0);
#ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_hit);
#else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit);
#endif
if (hit_mask != 0) {
return __ffs(static_cast<int>(hit_mask)) - 1;
} else {
......@@ -238,19 +276,35 @@ struct SetContext {
int insert_way = -1;
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_key == key && lane_age != 0);
#else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0);
#endif
if (hit_mask != 0) {
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
#ifdef WITH_ROCM
const int insert_way_age = __shfl(lane_age, insert_way);
#else
const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way);
#endif
if (lane_age > insert_way_age) {
lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) {
lane_age = kWarpSize;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
if (insert_way == -1) {
const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);
#ifdef WITH_ROCM
const unsigned valid_mask = __ballot(lane_age != 0);
#else
const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);
#endif
if (valid_mask != kFullMask) {
insert_way = __popc(static_cast<int>(valid_mask));
if (lane_age > 0) {
......@@ -259,7 +313,11 @@ struct SetContext {
lane_age = kWarpSize;
keys[insert_way] = key;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
}
if (insert_way != -1) { ages[thread_ctx.lane_id] = lane_age; }
......@@ -270,15 +328,28 @@ struct SetContext {
const ThreadContext& thread_ctx, Key key, int* way, Key* evicted_key) {
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM
const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1;
#else
const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1;
#endif
#ifdef WITH_ROCM
*evicted_key = __shfl(lane_key, insert_way);
#else
*evicted_key = __shfl_sync(kFullMask, lane_key, insert_way);
#endif
if (thread_ctx.lane_id == insert_way) {
keys[insert_way] = key;
lane_age = kWarpSize;
} else if (lane_age > 1) {
lane_age -= 1;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
ages[thread_ctx.lane_id] = lane_age;
*way = insert_way;
}
......@@ -318,7 +389,11 @@ __global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_key
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
uint32_t n_warp_missing = 0;
Key warp_missing_key = 0;
uint32_t warp_missing_index = 0;
......@@ -333,7 +408,11 @@ __global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_key
warp_missing_key = key;
warp_missing_index = key_idx;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
n_warp_missing += 1;
} else if (!test_only) {
set_ctx.Read(cache_ctx, thread_ctx, way, values + key_idx * cache_ctx.line_size);
......@@ -342,15 +421,31 @@ __global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_key
if (n_warp_missing > 0) {
uint32_t base_missing_idx = 0;
if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing_keys, n_warp_missing); }
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
#ifdef WITH_ROCM
base_missing_idx = __shfl(base_missing_idx, 0);
#else
base_missing_idx = __shfl_sync(kFullMask, base_missing_idx, 0);
#endif
if (thread_ctx.lane_id < n_warp_missing) {
missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
}
......@@ -371,7 +466,11 @@ __global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, u
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
uint32_t n_warp_missing = 0;
Key warp_missing_key = 0;
uint32_t warp_missing_index = 0;
......@@ -390,7 +489,11 @@ __global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, u
warp_missing_key = key;
warp_missing_index = key_idx;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
n_warp_missing += 1;
}
set_ctx.Unlock(thread_ctx);
......@@ -398,13 +501,25 @@ __global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, u
if (n_warp_missing > 0) {
uint32_t base_missing_idx = 0;
if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing, n_warp_missing); }
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
#ifdef WITH_ROCM
base_missing_idx = __shfl(base_missing_idx, 0);
#else
base_missing_idx = __shfl_sync(kFullMask, base_missing_idx, 0);
#endif
if (thread_ctx.lane_id < n_warp_missing) {
missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
}
}
}
......@@ -427,7 +542,11 @@ __global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* key
block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
for (uint32_t i = 0; i < n_batch_keys; ++i) {
const uint32_t key_idx = batch_offset + i;
const Key key = block_keys[thread_ctx.warp_id_in_block][i];
......@@ -438,7 +557,11 @@ __global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* key
Key evicted_key = 0;
set_ctx.Evict(cache_ctx, thread_ctx, key, &evicted_way, &evicted_key);
if (thread_ctx.lane_id == 0) { evicted_keys[key_idx] = evicted_key; }
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
set_ctx.Read(cache_ctx, thread_ctx, evicted_way,
evicted_values + cache_ctx.line_size * key_idx);
set_ctx.Write(cache_ctx, thread_ctx, evicted_way,
......@@ -463,26 +586,52 @@ __global__ void DumpKernel(LruCacheContext<Key, Elem> cache_ctx, size_t start_ke
lane_key = cache_ctx.keys[warp_start_key_index + thread_ctx.lane_id];
lane_age = cache_ctx.ages[warp_start_key_index + thread_ctx.lane_id];
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
#ifdef WITH_ROCM
const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
#else
const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0));
#endif
if (key_count == 0) { continue; }
uint32_t offset = 0;
if (thread_ctx.lane_id == 0) { offset = atomicAdd(n_dumped, key_count); }
#ifdef WITH_ROCM
offset = __shfl(offset, 0);
#else
offset = __shfl_sync(kFullMask, offset, 0);
__syncwarp();
#endif
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
for (uint32_t i = 0; i < kWarpSize; ++i) {
const Key key = warp_keys[thread_ctx.warp_id_in_block][i];
const Key age = warp_ages[thread_ctx.warp_id_in_block][i];
if (age == 0) { continue; }
if (thread_ctx.lane_id == 0) { keys[offset] = key; }
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
for (uint32_t j = thread_ctx.lane_id; j < cache_ctx.line_size; j += kWarpSize) {
values[offset * cache_ctx.line_size + j] =
cache_ctx.lines[(warp_start_key_index + i) * cache_ctx.line_size + j];
cache_ctx
.lines[static_cast<size_t>(warp_start_key_index + i) * cache_ctx.line_size + j];
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
offset += 1;
}
}
......@@ -498,14 +647,14 @@ class LruCache : public Cache {
query_indices_buffer_(nullptr),
query_keys_buffer_(nullptr),
value_type_(options.value_type) {
OF_CUDA_CHECK(cudaGetDevice(&device_index_));
OF_CUDA_CHECK(GPU(GetDevice)(&device_index_));
InitLruCacheContext(options, &ctx_);
}
~LruCache() override {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(cudaFree(query_indices_buffer_));
OF_CUDA_CHECK(cudaFree(query_keys_buffer_));
OF_CUDA_CHECK(GPU(Free)(query_indices_buffer_));
OF_CUDA_CHECK(GPU(Free)(query_keys_buffer_));
}
DestroyLruCacheContext(&ctx_);
}
......@@ -520,11 +669,11 @@ class LruCache : public Cache {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length < max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(cudaFree(query_indices_buffer_));
OF_CUDA_CHECK(cudaFree(query_keys_buffer_));
OF_CUDA_CHECK(GPU(Free)(query_indices_buffer_));
OF_CUDA_CHECK(GPU(Free)(query_keys_buffer_));
}
OF_CUDA_CHECK(cudaMalloc(&query_indices_buffer_, query_length * sizeof(uint32_t)));
OF_CUDA_CHECK(cudaMalloc(&query_keys_buffer_, query_length * sizeof(Key)));
OF_CUDA_CHECK(GPU(Malloc)(&query_indices_buffer_, query_length * sizeof(uint32_t)));
OF_CUDA_CHECK(GPU(Malloc)(&query_keys_buffer_, query_length * sizeof(Key)));
max_query_length_ = query_length;
}
......@@ -534,18 +683,19 @@ class LruCache : public Cache {
void* missing_keys, uint32_t* missing_indices) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemsetAsync)(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(GetKernel<Key, Elem, true>, GetLaunchConfig(n_keys), ctx_, n_keys,
static_cast<const Key*>(keys), nullptr, n_missing,
static_cast<Key*>(missing_keys), missing_indices);
}
using Cache::Get;
void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemsetAsync)(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(GetKernel<Key, Elem, false>, GetLaunchConfig(n_keys), ctx_, n_keys,
static_cast<const Key*>(keys), static_cast<Elem*>(values), n_missing,
......@@ -556,7 +706,7 @@ class LruCache : public Cache {
uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(cudaMemsetAsync(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemsetAsync)(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(PutWithoutEvictingKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
n_keys, static_cast<const Key*>(keys),
......@@ -571,7 +721,7 @@ class LruCache : public Cache {
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, void* keys, void* values) override {
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(cudaMemsetAsync(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemsetAsync)(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
const uint64_t max_dump_keys = end_key_index - start_key_index;
cuda_stream->LaunchKernel(
DumpKernel<Key, Elem>,
......@@ -581,6 +731,11 @@ class LruCache : public Cache {
static_cast<Elem*>(values));
}
void ClearDirtyFlags() override {
// do nothing.
return;
}
void Clear() override { ClearLruCacheContext<Key, Elem>(&ctx_); }
private:
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Inspired by https://github.com/NVIDIA-Merlin/HugeCTR/blob/master/gpu_cache/src/nv_gpu_cache.cu
#include "oneflow/core/embedding/lru_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include <new>
#include <hip/hip_runtime.h>
namespace oneflow {
namespace embedding {
namespace {
constexpr int kWarpSize = 64;
constexpr int kNumWarpPerBlock = 2;
constexpr int kBlockSize = kNumWarpPerBlock * kWarpSize;
constexpr unsigned long long int kFullMask = 0xFFFFFFFFFFFFFFFFU;
// Builds the launch configuration used by every cache kernel: one warp per
// kWarpSize-key batch, i.e. ceil(n_keys / kNumWarpPerBlock) blocks of
// kBlockSize threads, with no dynamic shared memory.
ep::CudaLaunchConfig GetLaunchConfig(uint32_t n_keys) {
  const uint32_t num_blocks = (n_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock;
  return ep::CudaLaunchConfig(num_blocks, kBlockSize, 0);
}
// Identifies the calling thread within the warp-per-set processing scheme.
// Computed once at kernel entry from the builtin block/thread indices.
struct ThreadContext {
  __device__ ThreadContext() {
    const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    global_warp_id = tid / kWarpSize;
    warp_id_in_block = global_warp_id % kNumWarpPerBlock;  // NOLINT
    num_warps = gridDim.x * kNumWarpPerBlock;              // NOLINT
    lane_id = tid % kWarpSize;
  }
  uint32_t global_warp_id;    // index of this warp among all warps in the grid
  uint32_t warp_id_in_block;  // index of this warp within its block
  uint32_t num_warps;         // total number of warps in the grid
  uint32_t lane_id;           // lane of this thread within its warp (0..kWarpSize-1)
};
// Spin lock guarding one cache set. Lane 0 acquires/releases the flag with
// atomics while the remaining lanes wait at the barrier; the fences order the
// accesses performed inside the critical section.
class WarpMutexAtomicImpl {
 public:
  OF_DISALLOW_COPY_AND_MOVE(WarpMutexAtomicImpl);
  __device__ WarpMutexAtomicImpl() : flag_(0) {}
  __device__ ~WarpMutexAtomicImpl() = default;

  __device__ void Lock(const ThreadContext& thread_ctx) {
    if (thread_ctx.lane_id == 0) {
      // Spin until the unlocked (0) -> locked (1) transition succeeds.
      while (atomicCAS(&flag_, 0, 1) != 0) {}
    }
    __threadfence();
    __syncthreads();
  }

  __device__ void Unlock(const ThreadContext& thread_ctx) {
    __syncthreads();
    __threadfence();
    if (thread_ctx.lane_id == 0) { atomicExch(&flag_, 0); }
  }

 private:
  int32_t flag_;  // 0 = unlocked, 1 = locked
};
// POD descriptor of the LRU cache storage, passed by value to kernels.
// The cache is organized as n_set sets of kWarpSize ways; one warp operates
// on one set at a time, lane i owning way i.
template<typename Key, typename Elem>
struct LruCacheContext {
  Key* keys;      // [n_set * kWarpSize] key per way; meaningful only when the matching age != 0
  Elem* lines;    // [n_set * kWarpSize * line_size] cached values, one line per way
  uint8_t* ages;  // [n_set * kWarpSize] LRU age per way; 0 = empty, kWarpSize = most recent
  void* mutex;    // [n_set] raw storage holding one WarpMutexAtomicImpl per set
  uint64_t n_set;
  uint32_t line_size;  // number of Elem per cached value
  CacheOptions::MemoryKind value_memory_kind;  // where `lines` was allocated (device or host)
};
// Placement-constructs one WarpMutexAtomicImpl per cache set inside the raw
// `mutex` buffer; one thread initializes one set's lock.
__global__ void InitCacheSetMutex(uint32_t n_set, void* mutex) {
  using WarpMutex = WarpMutexAtomicImpl;
  const uint32_t set_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (set_idx >= n_set) { return; }
  new (reinterpret_cast<WarpMutex*>(mutex) + set_idx) WarpMutex;
}
// Resets the cache to the empty state: zeroes all keys and ages (age 0 marks an
// empty way) and re-initializes the per-set mutexes.
template<typename Key, typename Elem>
void ClearLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
  OF_CUDA_CHECK(hipMemset(ctx->keys, 0, ctx->n_set * kWarpSize * sizeof(Key)));
  OF_CUDA_CHECK(hipMemset(ctx->ages, 0, ctx->n_set * kWarpSize * sizeof(uint8_t)));
  InitCacheSetMutex<<<(ctx->n_set - 1 + 256) / 256, 256>>>(ctx->n_set, ctx->mutex);
  // A kernel launch only reports configuration errors via hipGetLastError();
  // without this check a failed launch would be silently ignored.
  OF_CUDA_CHECK(hipGetLastError());
}
// Allocates and zero-initializes all storage for an LRU cache that can hold at
// least options.capacity values. The value lines live on device or in
// (optionally NUMA-aware) pinned host memory depending on value_memory_kind.
template<typename Key, typename Elem>
void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>* ctx) {
  const size_t keys_size_per_set = kWarpSize * sizeof(Key);
  // NOTE(review): assumes options.value_size is a multiple of sizeof(Elem) — confirm at call sites.
  const uint32_t line_size = options.value_size / sizeof(Elem);
  const size_t lines_size_per_set = kWarpSize * line_size * sizeof(Elem);
  const size_t ages_size_per_set = kWarpSize * sizeof(uint8_t);
  int device = 0;
  OF_CUDA_CHECK(hipGetDevice(&device));
  // `major` is queried but never used on ROCm (the CUDA version selects a mutex
  // implementation by compute capability; here WarpMutexAtomicImpl is always used).
  int major = 0;
  OF_CUDA_CHECK(hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, device));
  size_t mutex_size_per_set = 0;
  mutex_size_per_set = sizeof(WarpMutexAtomicImpl);
  // Round the capacity up to whole sets of kWarpSize ways.
  const size_t n_set = (options.capacity - 1 + kWarpSize) / kWarpSize;
  CHECK_GT(n_set, 0);
  ctx->n_set = n_set;
  ctx->line_size = line_size;
  const size_t keys_size = n_set * keys_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->keys), keys_size));
  const size_t lines_size = n_set * lines_size_per_set;
  if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
    OF_CUDA_CHECK(hipMalloc(&(ctx->lines), lines_size));
  } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
    // Pinned host memory; NUMA-aware placement unless disabled via environment.
    if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
      OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&(ctx->lines)), lines_size));
    } else {
      OF_CUDA_CHECK(
          NumaAwareCudaMallocHost(device, reinterpret_cast<void**>(&ctx->lines), lines_size));
    }
  } else {
    UNIMPLEMENTED();
  }
  // Remembered so DestroyLruCacheContext frees `lines` with the matching API.
  ctx->value_memory_kind = options.value_memory_kind;
  const size_t ages_size = n_set * ages_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->ages), ages_size));
  const size_t mutex_size = n_set * mutex_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->mutex), mutex_size));
  ClearLruCacheContext(ctx);
}
// Releases all storage owned by `ctx`, freeing the value lines with the
// deallocator that matches how they were obtained in InitLruCacheContext.
template<typename Key, typename Elem>
void DestroyLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
  OF_CUDA_CHECK(hipFree(ctx->keys));
  switch (ctx->value_memory_kind) {
    case CacheOptions::MemoryKind::kDevice: OF_CUDA_CHECK(hipFree(ctx->lines)); break;
    case CacheOptions::MemoryKind::kHost: OF_CUDA_CHECK(hipHostFree(ctx->lines)); break;
    default: UNIMPLEMENTED();
  }
  OF_CUDA_CHECK(hipFree(ctx->ages));
  OF_CUDA_CHECK(hipFree(ctx->mutex));
}
// View of one cache set (kWarpSize ways) operated on cooperatively by one warp:
// lane i owns way i. Ages implement LRU: 0 = empty way, kWarpSize = most
// recently used; touching a way decrements the ages of the ways younger than it.
//
// BUGFIX: the wavefront is 64 lanes wide (kWarpSize == 64, kFullMask is 64-bit),
// so __ballot returns a 64-bit mask. The previous code truncated that mask to
// int before __ffs/__popc, silently losing hits and occupancy information in
// lanes 32..63. The 64-bit variants __ffsll/__popcll are used instead.
template<typename Key, typename Elem>
struct SetContext {
  using WarpMutex = WarpMutexAtomicImpl;
  __device__ SetContext(const LruCacheContext<Key, Elem>& ctx, uint32_t set_id)
      : keys(ctx.keys + set_id * kWarpSize),
        mutex(reinterpret_cast<WarpMutex*>(ctx.mutex) + set_id),
        ages(ctx.ages + set_id * kWarpSize),
        // Promote to size_t before multiplying: set_id * kWarpSize * line_size can
        // overflow 32 bits for large caches (mirrors the CUDA implementation).
        lines(ctx.lines + static_cast<size_t>(set_id) * kWarpSize * ctx.line_size) {}
  // Returns the way holding `key`, or -1 when the key is not resident.
  __device__ int Lookup(const ThreadContext& thread_ctx, Key key) {
    const Key lane_key = keys[thread_ctx.lane_id];
    const int lane_age = ages[thread_ctx.lane_id];
    const bool lane_hit = (lane_key == key && lane_age != 0);
    const unsigned long long int hit_mask = __ballot(lane_hit);
    if (hit_mask != 0) {
      return static_cast<int>(__ffsll(hit_mask)) - 1;  // 64-bit ffs: do not truncate the mask
    } else {
      return -1;
    }
  }
  // Copies the cached line of `way` into `line`, lanes striding over elements.
  __device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
                       int way, Elem* line) {
    const Elem* from_line = lines + way * cache_ctx.line_size;
    for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
      line[i] = from_line[i];
    }
  }
  // Inserts `key` when it is already resident (refresh) or the set has a free
  // way. Returns the way used, or -1 when the set is full (eviction required).
  __device__ int InsertWithoutEvicting(const LruCacheContext<Key, Elem>& cache_ctx,
                                       const ThreadContext& thread_ctx, Key key) {
    int insert_way = -1;
    const Key lane_key = keys[thread_ctx.lane_id];
    int lane_age = ages[thread_ctx.lane_id];
    const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
    if (hit_mask != 0) {
      // Already cached: make it the most recent and age the younger ways.
      insert_way = static_cast<int>(__ffsll(hit_mask)) - 1;
      const int insert_way_age = __shfl(lane_age, insert_way);
      if (lane_age > insert_way_age) {
        lane_age -= 1;
      } else if (thread_ctx.lane_id == insert_way) {
        lane_age = kWarpSize;
      }
      // NOTE(review): block-wide barrier inherited from the CUDA version's
      // __syncwarp; all warps of the block must reach it — confirm on ROCm.
      __syncthreads();
    }
    if (insert_way == -1) {
      const unsigned long long int valid_mask = __ballot(lane_age != 0);
      if (valid_mask != kFullMask) {
        // Occupied ways fill the low lanes, so their 64-bit popcount is the
        // index of the first free way (__popc would truncate the mask).
        insert_way = static_cast<int>(__popcll(valid_mask));
        if (lane_age > 0) {
          lane_age -= 1;
        } else if (thread_ctx.lane_id == insert_way) {
          lane_age = kWarpSize;
          keys[insert_way] = key;
        }
        __syncthreads();
      }
    }
    if (insert_way != -1) { ages[thread_ctx.lane_id] = lane_age; }
    return insert_way;
  }
  // Replaces the LRU way (the one whose age decayed to 1) with `key`, returning
  // the chosen way and the key it displaced.
  __device__ void Evict(const LruCacheContext<Key, Elem>& cache_ctx,
                        const ThreadContext& thread_ctx, Key key, int* way, Key* evicted_key) {
    const Key lane_key = keys[thread_ctx.lane_id];
    int lane_age = ages[thread_ctx.lane_id];
    const int insert_way = static_cast<int>(__ffsll(__ballot(lane_age == 1))) - 1;
    *evicted_key = __shfl(lane_key, insert_way);
    if (thread_ctx.lane_id == insert_way) {
      keys[insert_way] = key;
      lane_age = kWarpSize;
    } else if (lane_age > 1) {
      lane_age -= 1;
    }
    __syncthreads();
    ages[thread_ctx.lane_id] = lane_age;
    *way = insert_way;
  }
  // Copies `line` into the cached line of `way`, lanes striding over elements.
  __device__ void Write(const LruCacheContext<Key, Elem>& cache_ctx,
                        const ThreadContext& thread_ctx, int way, const Elem* line) {
    Elem* to_line = lines + way * cache_ctx.line_size;
    for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
      to_line[i] = line[i];
    }
  }
  __device__ void Lock(const ThreadContext& thread_ctx) { mutex->Lock(thread_ctx); }
  __device__ void Unlock(const ThreadContext& thread_ctx) { mutex->Unlock(thread_ctx); }
  Key* keys;
  Elem* lines;
  uint8_t* ages;
  WarpMutex* mutex;
};
// Looks up num_keys keys in the cache. For each hit, copies the cached line into
// `values` (skipped when test_only). Misses are compacted into missing_keys /
// missing_indices; the caller zeroes *n_missing_keys before launch (see
// LruCache::Get / LruCache::Test). Launch: one warp handles kWarpSize keys per
// batch (see GetLaunchConfig).
template<typename Key, typename Elem, bool test_only>
__global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys, const Key* keys,
                          Elem* values, uint32_t* n_missing_keys, Key* missing_keys,
                          uint32_t* missing_indices) {
  ThreadContext thread_ctx{};
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
    // Stage this warp's batch of keys and their set ids in shared memory so all
    // lanes can iterate over them together below.
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    __syncthreads();
    uint32_t n_warp_missing = 0;
    Key warp_missing_key = 0;
    uint32_t warp_missing_index = 0;
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      const int way = set_ctx.Lookup(thread_ctx, key);
      if (way < 0) {
        // Miss: lane n_warp_missing records the pair, so after the loop lane k
        // holds this warp's k-th missing (key, index).
        if (thread_ctx.lane_id == n_warp_missing) {
          warp_missing_key = key;
          warp_missing_index = key_idx;
        }
        // NOTE(review): block-wide barrier inside a warp-dependent branch (a
        // blind port of the CUDA __syncwarp); if the block's warps disagree on
        // this branch the barrier counts mismatch — confirm safety on ROCm.
        __syncthreads();
        n_warp_missing += 1;
      } else if (!test_only) {
        set_ctx.Read(cache_ctx, thread_ctx, way, values + key_idx * cache_ctx.line_size);
      }
    }
    // Publish this warp's misses: lane 0 reserves a contiguous output range with
    // one atomicAdd, then the first n_warp_missing lanes write their pair.
    if (n_warp_missing > 0) {
      uint32_t base_missing_idx = 0;
      if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing_keys, n_warp_missing); }
      __syncthreads();
      base_missing_idx = __shfl(base_missing_idx, 0);
      if (thread_ctx.lane_id < n_warp_missing) {
        missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
        missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
      }
      __syncthreads();
    }
    __syncthreads();
  }
}
// Phase 1 of Put: inserts each (key, value) whose set still has the key or a
// free way. Keys belonging to a full set are compacted into missing_keys /
// missing_indices for the subsequent EvictKernel pass; the caller zeroes
// *n_missing before launch. Launch: one warp handles kWarpSize keys per batch.
template<typename Key, typename Elem>
__global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys,
                                         const Key* keys, const Elem* values, uint32_t* n_missing,
                                         Key* missing_keys, uint32_t* missing_indices) {
  ThreadContext thread_ctx{};
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
    // Stage the warp's batch of keys and their set ids in shared memory.
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    __syncthreads();
    uint32_t n_warp_missing = 0;
    Key warp_missing_key = 0;
    uint32_t warp_missing_index = 0;
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      // The per-set lock serializes warps that hash to the same set.
      set_ctx.Lock(thread_ctx);
      Key evicted_key = 0;  // NOTE(review): unused in this kernel
      const int insert_way = set_ctx.InsertWithoutEvicting(cache_ctx, thread_ctx, key);
      if (insert_way >= 0) {
        set_ctx.Write(cache_ctx, thread_ctx, insert_way, values + cache_ctx.line_size * key_idx);
      } else {
        // Set full: lane n_warp_missing records the pair for the eviction pass.
        if (thread_ctx.lane_id == n_warp_missing) {
          warp_missing_key = key;
          warp_missing_index = key_idx;
        }
        // NOTE(review): block-wide barrier inside a warp-dependent branch (ported
        // from __syncwarp) — confirm all warps of the block reach it on ROCm.
        __syncthreads();
        n_warp_missing += 1;
      }
      set_ctx.Unlock(thread_ctx);
    }
    // Publish this warp's overflow keys via one atomicAdd reservation by lane 0.
    if (n_warp_missing > 0) {
      uint32_t base_missing_idx = 0;
      if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing, n_warp_missing); }
      __syncthreads();
      base_missing_idx = __shfl(base_missing_idx, 0);
      if (thread_ctx.lane_id < n_warp_missing) {
        missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
        missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
      }
      __syncthreads();
    }
  }
}
// Phase 2 of Put: for each of the *n_evict keys that PutWithoutEvictingKernel
// could not place, evicts the LRU way of the key's set, writes the displaced
// (key, line) into evicted_keys / evicted_values, and stores the new value
// (values indexed indirectly through indices[key_idx], i.e. by the key's
// position in the original Put request).
template<typename Key, typename Elem>
__global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* keys,
                            const uint32_t* indices, const Elem* values, const uint32_t* n_evict,
                            Key* evicted_keys, Elem* evicted_values) {
  ThreadContext thread_ctx{};
  // n_evict lives in device memory (written by the previous kernel on the same
  // stream), so it is read on-device rather than on the host.
  uint32_t num_evict = *n_evict;
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_evict;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    const uint32_t n_batch_keys = min(kWarpSize, num_evict - batch_offset);
    // Stage the warp's batch of keys and their set ids in shared memory.
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    __syncthreads();
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const uint32_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      // Per-set lock: warps targeting the same set evict one at a time.
      set_ctx.Lock(thread_ctx);
      int evicted_way = -1;
      Key evicted_key = 0;
      set_ctx.Evict(cache_ctx, thread_ctx, key, &evicted_way, &evicted_key);
      if (thread_ctx.lane_id == 0) { evicted_keys[key_idx] = evicted_key; }
      __syncthreads();
      // Save the displaced line before overwriting the way with the new value.
      set_ctx.Read(cache_ctx, thread_ctx, evicted_way,
                   evicted_values + cache_ctx.line_size * key_idx);
      set_ctx.Write(cache_ctx, thread_ctx, evicted_way,
                    values + cache_ctx.line_size * indices[key_idx]);
      set_ctx.Unlock(thread_ctx);
    }
  }
}
// Writes every resident (age != 0) key/value pair whose slot index lies in
// [start_key_index, end_key_index) to keys/values, compacting them through the
// n_dumped atomic counter (zeroed by the caller). One warp scans kWarpSize
// slots per iteration.
//
// BUGFIX: the 64-lane wavefront's __ballot mask is 64-bit; the previous
// __popc(static_cast<int>(...)) truncated it, dropping occupied ways in lanes
// 32..63. Also promotes the line index to size_t before scaling by line_size
// (32-bit overflow for large caches; mirrors the CUDA implementation).
template<typename Key, typename Elem>
__global__ void DumpKernel(LruCacheContext<Key, Elem> cache_ctx, size_t start_key_index,
                           size_t end_key_index, uint32_t* n_dumped, Key* keys, Elem* values) {
  ThreadContext thread_ctx{};
  __shared__ Key warp_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ uint8_t warp_ages[kNumWarpPerBlock][kWarpSize];
  // NOTE(review): warp_start_key_index is uint32_t while the bounds are size_t;
  // caches with more than 2^32 ways would truncate — confirm capacity limits.
  for (uint32_t warp_start_key_index = start_key_index + thread_ctx.global_warp_id * kWarpSize;
       warp_start_key_index < end_key_index;
       warp_start_key_index += thread_ctx.num_warps * kWarpSize) {
    Key lane_key = 0;
    uint8_t lane_age = 0;  // age 0 also marks the out-of-range tail lanes as empty
    if (warp_start_key_index + thread_ctx.lane_id < end_key_index) {
      lane_key = cache_ctx.keys[warp_start_key_index + thread_ctx.lane_id];
      lane_age = cache_ctx.ages[warp_start_key_index + thread_ctx.lane_id];
    }
    __syncthreads();
    warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
    warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
    // 64-bit popcount of the 64-bit ballot mask (__popc would truncate it).
    const int key_count = static_cast<int>(__popcll(__ballot(lane_age != 0)));
    if (key_count == 0) { continue; }
    // Lane 0 reserves key_count contiguous output slots, broadcast to the warp.
    uint32_t offset = 0;
    if (thread_ctx.lane_id == 0) { offset = atomicAdd(n_dumped, key_count); }
    offset = __shfl(offset, 0);
    __syncthreads();
    for (uint32_t i = 0; i < kWarpSize; ++i) {
      const Key key = warp_keys[thread_ctx.warp_id_in_block][i];
      const uint8_t age = warp_ages[thread_ctx.warp_id_in_block][i];  // was mistyped as Key
      if (age == 0) { continue; }
      if (thread_ctx.lane_id == 0) { keys[offset] = key; }
      __syncthreads();
      for (uint32_t j = thread_ctx.lane_id; j < cache_ctx.line_size; j += kWarpSize) {
        values[static_cast<size_t>(offset) * cache_ctx.line_size + j] =
            cache_ctx
                .lines[static_cast<size_t>(warp_start_key_index + i) * cache_ctx.line_size + j];
      }
      __syncthreads();
      offset += 1;
    }
  }
}
template<typename Key, typename Elem>
class LruCache : public Cache {
public:
OF_DISALLOW_COPY_AND_MOVE(LruCache);
// Records the current device (so the destructor can restore it via the guard)
// and allocates all cache storage per `options`. Scratch buffers for Put are
// allocated lazily by ReserveQueryLength.
explicit LruCache(const CacheOptions& options)
    : device_index_{},
      max_query_length_(0),
      query_indices_buffer_(nullptr),
      query_keys_buffer_(nullptr),
      value_type_(options.value_type) {
  OF_CUDA_CHECK(hipGetDevice(&device_index_));
  InitLruCacheContext(options, &ctx_);
}
// Frees the lazily-allocated scratch buffers (present iff max_query_length_
// is non-zero) and the cache storage, on the device the cache was created on.
~LruCache() override {
  CudaCurrentDeviceGuard guard(device_index_);
  if (max_query_length_ != 0) {
    OF_CUDA_CHECK(hipFree(query_indices_buffer_));
    OF_CUDA_CHECK(hipFree(query_keys_buffer_));
  }
  DestroyLruCacheContext(&ctx_);
}
uint32_t KeySize() const override { return sizeof(Key); }
// Value size in bytes: line_size elements of Elem per cached value.
uint32_t ValueSize() const override { return sizeof(Elem) * ctx_.line_size; }
DataType ValueType() const override { return value_type_; }
// Effective capacity after rounding up to whole sets of kWarpSize ways.
uint64_t Capacity() const override { return ctx_.n_set * kWarpSize; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
// Ensures the internal key/index scratch buffers (used by Put) can hold
// `query_length` entries. Growth-only: a request no larger than the current
// capacity is a no-op. (The original test used `<`, which on equality freed and
// re-created identically-sized buffers — a redundant synchronous realloc.)
void ReserveQueryLength(uint32_t query_length) override {
  CudaCurrentDeviceGuard guard(device_index_);
  if (query_length <= max_query_length_) { return; }
  if (max_query_length_ != 0) {
    OF_CUDA_CHECK(hipFree(query_indices_buffer_));
    OF_CUDA_CHECK(hipFree(query_keys_buffer_));
  }
  OF_CUDA_CHECK(hipMalloc(&query_indices_buffer_, query_length * sizeof(uint32_t)));
  OF_CUDA_CHECK(hipMalloc(&query_keys_buffer_, query_length * sizeof(Key)));
  max_query_length_ = query_length;
}
CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kLRU; }
// Membership-only lookup: fills n_missing / missing_keys / missing_indices
// without reading any values (GetKernel with test_only = true). All outputs are
// device pointers; work is enqueued asynchronously on `stream`.
void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
          void* missing_keys, uint32_t* missing_indices) override {
  CHECK_LE(n_keys, max_query_length_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  // The kernel accumulates misses with atomicAdd, so the counter starts at zero.
  OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
  if (n_keys == 0) { return; }
  cuda_stream->LaunchKernel(GetKernel<Key, Elem, true>, GetLaunchConfig(n_keys), ctx_, n_keys,
                            static_cast<const Key*>(keys), nullptr, n_missing,
                            static_cast<Key*>(missing_keys), missing_indices);
}
void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values, uint32_t* n_missing,
void* missing_keys, uint32_t* missing_indices) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(GetKernel<Key, Elem, false>, GetLaunchConfig(n_keys), ctx_, n_keys,
static_cast<const Key*>(keys), static_cast<Elem*>(values), n_missing,
static_cast<Key*>(missing_keys), missing_indices);
}
void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override {
CHECK_LE(n_keys, max_query_length_);
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
if (n_keys == 0) { return; }
cuda_stream->LaunchKernel(PutWithoutEvictingKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
n_keys, static_cast<const Key*>(keys),
static_cast<const Elem*>(values), n_evicted, query_keys_buffer_,
query_indices_buffer_);
cuda_stream->LaunchKernel(EvictKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
query_keys_buffer_, query_indices_buffer_,
static_cast<const Elem*>(values), n_evicted,
static_cast<Key*>(evicted_keys), static_cast<Elem*>(evicted_values));
}
void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, void* keys, void* values) override {
auto cuda_stream = stream->As<ep::CudaStream>();
OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
const uint64_t max_dump_keys = end_key_index - start_key_index;
cuda_stream->LaunchKernel(
DumpKernel<Key, Elem>,
ep::CudaLaunchConfig((max_dump_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock, kBlockSize,
0),
ctx_, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
static_cast<Elem*>(values));
}
void Clear() override { ClearLruCacheContext<Key, Elem>(&ctx_); }
private:
int device_index_;
uint32_t max_query_length_;
LruCacheContext<Key, Elem> ctx_;
uint32_t* query_indices_buffer_;
Key* query_keys_buffer_;
DataType value_type_;
};
// Selects the element type used to move cache lines, preferring the widest type
// that evenly divides the configured value size (presumably to minimize the
// number of device loads/stores per line — behavior depends only on divisibility).
template<typename Key>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
  const auto value_size = options.value_size;
  if (value_size % sizeof(ulonglong2) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, ulonglong2>(options));
  }
  if (value_size % sizeof(uint64_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint64_t>(options));
  }
  if (value_size % sizeof(uint32_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint32_t>(options));
  }
  if (value_size % sizeof(uint16_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint16_t>(options));
  }
  // Odd sizes fall back to byte-wise transfers.
  return std::unique_ptr<Cache>(new LruCache<Key, uint8_t>(options));
}
// Dispatches on the configured key width; only 4-byte and 8-byte keys are supported.
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
  switch (options.key_size) {
    case sizeof(uint32_t): return DispatchValueType<uint32_t>(options);
    case sizeof(uint64_t): return DispatchValueType<uint64_t>(options);
    default: {
      UNIMPLEMENTED();
      return nullptr;
    }
  }
}
} // namespace
std::unique_ptr<Cache> NewLruCache(const CacheOptions& options) { return DispatchKeyType(options); }
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
......@@ -50,14 +50,14 @@ class IteratorImpl : public KVIterator {
std::memcpy(reinterpret_cast<char*>(host_values_buffer_) + *host_num_buffer_ * value_size_,
pos_->second.data(), value_size_);
}
OF_CUDA_CHECK(cudaMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(n_result, host_num_buffer_, sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
const uint32_t num_keys = *host_num_buffer_;
if (num_keys != 0) {
OF_CUDA_CHECK(cudaMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,
cudaMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(cudaMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
cudaMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemcpyAsync)(keys, host_keys_buffer_, num_keys * key_size_,
GPU(MemcpyDefault), cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemcpyAsync)(values, host_values_buffer_, num_keys * value_size_,
GPU(MemcpyDefault), cuda_stream->cuda_stream()));
}
}
......@@ -80,7 +80,7 @@ class KeyValueStoreImpl : public KeyValueStore {
OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
explicit KeyValueStoreImpl(const MockKeyValueStoreOptions& options)
: device_index_(-1), max_query_length_(0) {
OF_CUDA_CHECK(cudaGetDevice(&device_index_));
OF_CUDA_CHECK(GPU(GetDevice)(&device_index_));
key_size_ = options.key_size;
value_size_ = options.value_size;
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
......@@ -97,11 +97,11 @@ class KeyValueStoreImpl : public KeyValueStore {
~KeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(cudaFreeHost(host_query_keys_));
OF_CUDA_CHECK(cudaFreeHost(host_query_values_));
OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_));
OF_CUDA_CHECK(GPU(FreeHost)(host_query_keys_));
OF_CUDA_CHECK(GPU(FreeHost)(host_query_values_));
OF_CUDA_CHECK(GPU(FreeHost)(host_missing_indices_));
}
OF_CUDA_CHECK(cudaFreeHost(host_n_missing_));
OF_CUDA_CHECK(GPU(FreeHost)(host_n_missing_));
}
uint32_t KeySize() const override { return key_size_; }
......@@ -114,9 +114,9 @@ class KeyValueStoreImpl : public KeyValueStore {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(cudaFreeHost(host_query_keys_));
OF_CUDA_CHECK(cudaFreeHost(host_query_values_));
OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_));
OF_CUDA_CHECK(GPU(FreeHost)(host_query_keys_));
OF_CUDA_CHECK(GPU(FreeHost)(host_query_values_));
OF_CUDA_CHECK(GPU(FreeHost)(host_missing_indices_));
}
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
......@@ -128,6 +128,7 @@ class KeyValueStoreImpl : public KeyValueStore {
max_query_length_ = query_length;
}
using KeyValueStore::Get;
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
......@@ -158,11 +159,11 @@ void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const vo
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) {
OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t),
OF_CUDA_CHECK(GPU(MemsetAsync)(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(host_query_keys_, keys, key_size_ * num_keys, GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
*host_n_missing_ = 0;
......@@ -175,12 +176,12 @@ void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const vo
*host_n_missing_ += 1;
}
}
OF_CUDA_CHECK(cudaMemcpyAsync(values, host_query_values_, num_keys * value_size_,
cudaMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(cudaMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(values, host_query_values_, num_keys * value_size_,
GPU(MemcpyDefault), cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemcpyAsync)(n_missing, host_n_missing_, sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(cudaMemcpyAsync(missing_indices, host_missing_indices_,
(*host_n_missing_) * sizeof(uint32_t), cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(missing_indices, host_missing_indices_,
(*host_n_missing_) * sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
}
......@@ -191,10 +192,10 @@ void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const vo
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) { return; }
OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(host_query_keys_, keys, key_size_ * num_keys, GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(cudaMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
cudaMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemcpyAsync)(host_query_values_, values, value_size_ * num_keys,
GPU(MemcpyDefault), cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
for (uint32_t i = 0; i < num_keys; ++i) {
store_[host_query_keys_[i]] = std::string(
......
......@@ -22,17 +22,7 @@ namespace oneflow {
namespace embedding {
#ifdef WITH_CUDA
struct MockKeyValueStoreOptions {
uint32_t key_size = 0;
uint32_t value_size = 0;
};
std::unique_ptr<KeyValueStore> NewMockKeyValueStore(const MockKeyValueStoreOptions& options);
#endif // WITH_CUDA
#ifdef WITH_ROCM
#if defined(WITH_CUDA) || defined(WITH_ROCM)
struct MockKeyValueStoreOptions {
uint32_t key_size = 0;
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/mock_key_value_store.h"
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace embedding {
namespace {
// Iterates over the in-memory mock store, handing out batches of key/value pairs
// to device memory. Not thread-safe; the caller is expected to serialize access.
template<typename Key>
class IteratorImpl : public KVIterator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
  // host_*_buffer are host staging buffers sized for max_query_length entries
  // (allocated by the owning KeyValueStoreImpl, which passes pinned memory).
  IteratorImpl(HashMap<Key, std::string>* store, uint32_t key_size, uint32_t value_size,
               uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
               uint32_t* host_num_buffer)
      : store_(store),
        pos_(store->begin()),
        key_size_(key_size),
        value_size_(value_size),
        max_query_length_(max_query_length),
        host_keys_buffer_(host_keys_buffer),
        host_values_buffer_(host_values_buffer),
        host_num_buffer_(host_num_buffer) {}
  ~IteratorImpl() override = default;
  // Copies up to n_request entries into the device buffers keys/values and writes
  // the actual batch size to n_result (a device pointer).
  void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
             void* values) override {
    CHECK_LE(n_request, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    // Drain in-flight work so the host staging buffers are safe to overwrite.
    CHECK_JUST(cuda_stream->Sync());
    *host_num_buffer_ = 0;
    while (*host_num_buffer_ < n_request && pos_ != store_->end()) {
      reinterpret_cast<Key*>(host_keys_buffer_)[*host_num_buffer_] = pos_->first;
      std::memcpy(reinterpret_cast<char*>(host_values_buffer_) + *host_num_buffer_ * value_size_,
                  pos_->second.data(), value_size_);
      // BUGFIX: advance both the output count and the map iterator. Without these
      // increments the loop condition never changes, so NextN spins forever on a
      // non-empty store (and keeps overwriting slot 0).
      *host_num_buffer_ += 1;
      ++pos_;
    }
    OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                                 cuda_stream->cuda_stream()));
    // Reading *host_num_buffer_ here is safe: the value was produced by host code
    // above, not by the async copy.
    const uint32_t num_keys = *host_num_buffer_;
    if (num_keys != 0) {
      OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,
                                   hipMemcpyDefault, cuda_stream->cuda_stream()));
      OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
                                   hipMemcpyDefault, cuda_stream->cuda_stream()));
    }
  }
  // Restarts iteration from the beginning of the store.
  void Reset() override { pos_ = store_->begin(); }

 private:
  HashMap<Key, std::string>* store_;
  typename HashMap<Key, std::string>::iterator pos_;
  uint32_t key_size_;
  uint32_t value_size_;
  uint32_t max_query_length_;
  void* host_keys_buffer_;
  void* host_values_buffer_;
  uint32_t* host_num_buffer_;
};
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
public:
OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
explicit KeyValueStoreImpl(const MockKeyValueStoreOptions& options)
: device_index_(-1), max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
key_size_ = options.key_size;
value_size_ = options.value_size;
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_query_values_),
value_size_ * max_query_length_));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),
sizeof(uint32_t)));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * max_query_length_));
}
~KeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(hipHostFree(host_n_missing_));
}
uint32_t KeySize() const override { return key_size_; }
uint32_t ValueSize() const override { return value_size_; }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipHostFree(host_query_keys_));
OF_CUDA_CHECK(hipHostFree(host_query_values_));
OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
}
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));
OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
reinterpret_cast<void**>(&host_missing_indices_),
sizeof(uint32_t) * query_length));
max_query_length_ = query_length;
}
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
bool SnapshotExists(const std::string& name) override;
void LoadSnapshot(const std::string& name) override;
void LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) override;
void SaveSnapshot(const std::string& name) override;
private:
int device_index_;
uint32_t max_query_length_;
uint32_t key_size_;
uint32_t value_size_;
Key* host_query_keys_{};
uint8_t* host_query_values_{};
uint32_t* host_n_missing_{};
uint32_t* host_missing_indices_{};
HashMap<Key, std::string> store_;
HashMap<std::string, HashMap<Key, std::string>> snapshots_;
std::mutex mutex_;
};
// Looks up num_keys device-side keys. Found values are written to `values`;
// indices of missing keys are compacted into `missing_indices` with the count
// in `n_missing` (both device pointers).
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 void* values, uint32_t* n_missing, uint32_t* missing_indices) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto* gpu_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) {
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), gpu_stream->cuda_stream()));
    return;
  }
  // Stage the device keys into pinned host memory and wait for the copy.
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               gpu_stream->cuda_stream()));
  CHECK_JUST(gpu_stream->Sync());
  *host_n_missing_ = 0;
  for (uint32_t i = 0; i < num_keys; ++i) {
    const auto it = store_.find(host_query_keys_[i]);
    if (it == store_.end()) {
      // Record the position of the miss; the value slot for i is left untouched.
      host_missing_indices_[*host_n_missing_] = i;
      *host_n_missing_ += 1;
    } else {
      std::memcpy(host_query_values_ + i * value_size_, it->second.data(), value_size_);
    }
  }
  // Publish results back to the device; the caller owns synchronization.
  OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_,
                               hipMemcpyDefault, gpu_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
                               gpu_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
                               (*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
                               gpu_stream->cuda_stream()));
}
// Stores num_keys device-side key/value pairs into the host map, overwriting
// existing entries for the same keys.
template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 const void* values) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto* gpu_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) { return; }
  // Stage device keys and values into pinned host buffers, then block until both
  // copies have finished before reading them on the host.
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               gpu_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
                               hipMemcpyDefault, gpu_stream->cuda_stream()));
  CHECK_JUST(gpu_stream->Sync());
  const char* value_bytes = reinterpret_cast<const char*>(host_query_values_);
  for (uint32_t i = 0; i < num_keys; ++i) {
    store_[host_query_keys_[i]].assign(value_bytes + i * value_size_, value_size_);
  }
}
// True iff a snapshot was previously saved (or created by LoadSnapshot) under `name`.
template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
  return snapshots_.count(name) != 0;
}
// Convenience overload: load the named snapshot without a post-load hook.
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  LoadSnapshot(name, {});
}
// Replaces the live store with the named snapshot and, if provided, invokes Hook
// with an iterator over the freshly loaded contents.
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
                                          const std::function<void(KVIterator* iter)>& Hook) {
  CudaCurrentDeviceGuard guard(device_index_);
  // NOTE(review): operator[] inserts an empty snapshot when `name` was never
  // saved, silently loading an empty store — confirm this is intended.
  store_ = snapshots_[name];
  if (!Hook) { return; }
  IteratorImpl<Key> iterator(&store_, KeySize(), ValueSize(), max_query_length_, host_query_keys_,
                             host_query_values_, host_n_missing_);
  Hook(&iterator);
}
// Copies the live store into the named snapshot slot, overwriting any previous one.
template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  auto& snapshot = snapshots_[name];
  snapshot = store_;
}
} // namespace
std::unique_ptr<KeyValueStore> NewMockKeyValueStore(const MockKeyValueStoreOptions& options) {
if (options.key_size == sizeof(uint64_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
} else if (options.key_size == sizeof(uint32_t)) {
return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
......@@ -395,6 +395,7 @@ class PersistentTableImpl : public PersistentTable {
PosixFile writable_key_file_;
uint64_t writable_key_file_chunk_id_;
PosixFileLockGuard lock_;
bool read_only_;
};
template<typename Key, typename Engine>
......@@ -405,14 +406,19 @@ PersistentTableImpl<Key, Engine>::PersistentTableImpl(const PersistentTableOptio
physical_block_size_(options.physical_block_size),
logical_block_size_(GetLogicalBlockSize(options.physical_block_size, value_size_)),
blocks_buffer_(options.physical_block_size),
writable_key_file_chunk_id_(-1) {
writable_key_file_chunk_id_(-1),
read_only_(options.read_only) {
const uint64_t capacity_hint = ParseIntegerFromEnv(
"ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_CAPACITY_HINT", options.capacity_hint);
if (capacity_hint > 0) { row_id_mapping_.reserve(capacity_hint); }
PosixFile::RecursiveCreateDirectory(options.path, 0755);
const std::string lock_filename = PosixFile::JoinPath(options.path, kLockFileName);
const bool init = !PosixFile::FileExists(lock_filename);
lock_ = PosixFileLockGuard(PosixFile(lock_filename, O_CREAT | O_RDWR, 0644));
if (read_only_) {
CHECK(!init) << "The table must be initialized in read only mode";
} else {
lock_ = PosixFileLockGuard(PosixFile(lock_filename, O_CREAT | O_RDWR, 0644));
}
const uint64_t target_chunk_size = options.target_chunk_size_mb * 1024 * 1024;
CHECK_GE(target_chunk_size, logical_block_size_);
num_logical_blocks_per_chunk_ = target_chunk_size / logical_block_size_,
......@@ -442,7 +448,8 @@ PersistentTableImpl<Key, Engine>::PersistentTableImpl(const PersistentTableOptio
for (auto& chunk : chunks) {
if (value_files_.size() <= chunk.first) { value_files_.resize(chunk.first + 1); }
CHECK_EQ(value_files_.at(chunk.first).fd(), -1);
PosixFile value_file(chunk.second, O_RDWR | O_DIRECT, 0644);
const int flags = read_only_ ? (O_RDONLY | O_DIRECT) : (O_RDWR | O_DIRECT);
PosixFile value_file(chunk.second, flags, 0644);
value_files_.at(chunk.first) = std::move(value_file);
}
if (!value_files_.empty()) {
......@@ -523,6 +530,7 @@ void PersistentTableImpl<Key, Engine>::Get(uint32_t num_keys, const void* keys,
template<typename Key, typename Engine>
void PersistentTableImpl<Key, Engine>::PutBlocks(uint32_t num_keys, const void* keys,
const void* blocks) {
CHECK(!read_only_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
const uint32_t num_blocks = RoundUp(num_keys, num_values_per_block_) / num_values_per_block_;
const uint32_t num_padded_keys = num_blocks * num_values_per_block_;
......@@ -579,6 +587,7 @@ void PersistentTableImpl<Key, Engine>::PutBlocks(uint32_t num_keys, const void*
template<typename Key, typename Engine>
void PersistentTableImpl<Key, Engine>::Put(uint32_t num_keys, const void* keys,
const void* values) {
CHECK(!read_only_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
const void* blocks_ptr = nullptr;
if (value_size_ == logical_block_size_
......@@ -656,6 +665,7 @@ void PersistentTableImpl<Key, Engine>::LoadSnapshotImpl(const std::string& name)
template<typename Key, typename Engine>
void PersistentTableImpl<Key, Engine>::SaveSnapshotImpl(const std::string& name) {
CHECK(!read_only_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
PosixFile::RecursiveCreateDirectory(SnapshotDirPath(name), 0755);
std::ofstream list_ofs(SnapshotListFilePath(name));
......@@ -704,13 +714,11 @@ template<typename Key, typename Engine>
void PersistentTableImpl<Key, Engine>::LoadSnapshot(
const std::string& name, const std::function<void(Iterator* iter)>& Hook) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
int mmap_flags = MAP_SHARED;
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_SNAPSHOT_LOAD_MAP_POPULATE",
true)) {
mmap_flags |= MAP_POPULATE;
}
const std::string snapshot_base = SnapshotDirPath(name);
const std::string snapshot_list = SnapshotListFilePath(name);
row_id_mapping_.clear();
......@@ -723,10 +731,8 @@ void PersistentTableImpl<Key, Engine>::LoadSnapshot(
CHECK_EQ(index_file_size % sizeof(uint64_t), 0);
if (index_file_size == 0) { return; }
const size_t n_entries = index_file_size / sizeof(uint64_t);
// PosixMappedFile mapped_index(std::move(index_file), index_file_size, PROT_READ);
PosixMappedFile mapped_index(std::move(index_file), index_file_size, PROT_READ, mmap_flags);
PosixFile key_file(KeyFilePath(chunk_id), O_RDONLY, 0644);
// PosixMappedFile mapped_key(std::move(key_file), key_file.Size(), PROT_READ);
PosixMappedFile mapped_key(std::move(key_file), key_file.Size(), PROT_READ, mmap_flags);
const uint64_t* indices = static_cast<const uint64_t*>(mapped_index.ptr());
const Key* keys = static_cast<const Key*>(mapped_key.ptr());
......@@ -737,7 +743,6 @@ void PersistentTableImpl<Key, Engine>::LoadSnapshot(
}
if (Hook) {
PosixFile value_file(ValueFilePath(chunk_id), O_RDONLY, 0644);
// PosixMappedFile mapped_value(std::move(value_file), value_file.Size(), PROT_READ);
PosixMappedFile mapped_value(std::move(value_file), value_file.Size(), PROT_READ, mmap_flags);
ChunkIteratorImpl<Key> chunk_iterator(value_size_, logical_block_size_, num_values_per_block_,
num_values_per_chunk_, chunk_id, n_entries, keys,
......
......@@ -29,6 +29,7 @@ struct PersistentTableOptions {
uint64_t target_chunk_size_mb = 4 * 1024;
uint16_t physical_block_size = 4096;
uint64_t capacity_hint = 0;
bool read_only = false;
};
class PersistentTable {
......
......@@ -49,14 +49,14 @@ class IteratorImpl : public KVIterator {
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_JUST(cuda_stream->Sync());
base_iter_->Next(n_request, host_num_buffer_, host_keys_buffer_, host_values_buffer_);
OF_CUDA_CHECK(cudaMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(n_result, host_num_buffer_, sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
const uint32_t num_keys = *host_num_buffer_;
if (num_keys != 0) {
OF_CUDA_CHECK(cudaMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,
cudaMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(cudaMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
cudaMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemcpyAsync)(keys, host_keys_buffer_, num_keys * key_size_,
GPU(MemcpyDefault), cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemcpyAsync)(values, host_values_buffer_, num_keys * value_size_,
GPU(MemcpyDefault), cuda_stream->cuda_stream()));
}
}
......@@ -78,7 +78,7 @@ class KeyValueStoreImpl : public KeyValueStore {
OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
explicit KeyValueStoreImpl(const PersistentTableKeyValueStoreOptions& options)
: device_index_(-1), max_query_length_(0) {
OF_CUDA_CHECK(cudaGetDevice(&device_index_));
OF_CUDA_CHECK(GPU(GetDevice)(&device_index_));
key_size_ = options.table_options.key_size;
value_size_ = options.table_options.value_size;
table_ = NewPersistentTable(options.table_options);
......@@ -96,11 +96,11 @@ class KeyValueStoreImpl : public KeyValueStore {
~KeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
if (max_query_length_ != 0) {
OF_CUDA_CHECK(cudaFreeHost(host_query_keys_));
OF_CUDA_CHECK(cudaFreeHost(host_query_values_));
OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_));
OF_CUDA_CHECK(GPU(FreeHost)(host_query_keys_));
OF_CUDA_CHECK(GPU(FreeHost)(host_query_values_));
OF_CUDA_CHECK(GPU(FreeHost)(host_missing_indices_));
}
OF_CUDA_CHECK(cudaFreeHost(host_n_missing_));
OF_CUDA_CHECK(GPU(FreeHost)(host_n_missing_));
}
uint32_t KeySize() const override { return key_size_; }
......@@ -113,9 +113,9 @@ class KeyValueStoreImpl : public KeyValueStore {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(cudaFreeHost(host_query_keys_));
OF_CUDA_CHECK(cudaFreeHost(host_query_values_));
OF_CUDA_CHECK(cudaFreeHost(host_missing_indices_));
OF_CUDA_CHECK(GPU(FreeHost)(host_query_keys_));
OF_CUDA_CHECK(GPU(FreeHost)(host_query_values_));
OF_CUDA_CHECK(GPU(FreeHost)(host_missing_indices_));
}
OF_CUDA_CHECK(NumaAwareCudaMallocHost(
device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
......@@ -127,6 +127,7 @@ class KeyValueStoreImpl : public KeyValueStore {
max_query_length_ = query_length;
}
using KeyValueStore::Get;
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
......@@ -157,23 +158,23 @@ void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const vo
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) {
OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t),
OF_CUDA_CHECK(GPU(MemsetAsync)(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(host_query_keys_, keys, key_size_ * num_keys, GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
table_->Get(num_keys, host_query_keys_, host_query_values_, host_n_missing_,
host_missing_indices_);
OF_CUDA_CHECK(cudaMemcpyAsync(values, host_query_values_, num_keys * value_size_,
cudaMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(cudaMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(values, host_query_values_, num_keys * value_size_,
GPU(MemcpyDefault), cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemcpyAsync)(n_missing, host_n_missing_, sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(cudaMemcpyAsync(missing_indices, host_missing_indices_,
(*host_n_missing_) * sizeof(uint32_t), cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(missing_indices, host_missing_indices_,
(*host_n_missing_) * sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
}
......@@ -184,10 +185,10 @@ void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const vo
auto cuda_stream = stream->As<ep::CudaStream>();
CHECK_LE(num_keys, max_query_length_);
if (num_keys == 0) { return; }
OF_CUDA_CHECK(cudaMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, cudaMemcpyDefault,
OF_CUDA_CHECK(GPU(MemcpyAsync)(host_query_keys_, keys, key_size_ * num_keys, GPU(MemcpyDefault),
cuda_stream->cuda_stream()));
OF_CUDA_CHECK(cudaMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
cudaMemcpyDefault, cuda_stream->cuda_stream()));
OF_CUDA_CHECK(GPU(MemcpyAsync)(host_query_values_, values, value_size_ * num_keys,
GPU(MemcpyDefault), cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
table_->Put(num_keys, host_query_keys_, host_query_values_);
}
......
......@@ -23,7 +23,7 @@ namespace oneflow {
namespace embedding {
#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_ROCM)
struct PersistentTableKeyValueStoreOptions {
PersistentTableOptions table_options{};
......@@ -33,16 +33,6 @@ std::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(
const PersistentTableKeyValueStoreOptions& options);
#endif // WITH_CUDA
#ifdef WITH_ROCM
struct PersistentTableKeyValueStoreOptions {
PersistentTableOptions table_options{};
};
std::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(
const PersistentTableKeyValueStoreOptions& options);
#endif // WITH_ROCM
} // namespace embedding
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/persistent_table_key_value_store.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/persistent_table.h"
#include <robin_hood.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <dirent.h>
namespace oneflow {
namespace embedding {
namespace {
// Adapts a PersistentTable::Iterator to the KVIterator interface, streaming
// batches of key/value pairs from the table into device memory.
// Not thread-safe; presumably driven from a single snapshot-loading path —
// TODO(review): confirm against callers.
class IteratorImpl : public KVIterator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
  // host_*_buffer must hold at least max_query_length entries; they are owned by
  // the enclosing KeyValueStoreImpl (NOTE(review): assumed pinned, given the
  // async copies below — confirm with the allocator used by the caller).
  IteratorImpl(PersistentTable::Iterator* base_iter, uint32_t key_size, uint32_t value_size,
               uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
               uint32_t* host_num_buffer)
      : base_iter_(base_iter),
        key_size_(key_size),
        value_size_(value_size),
        max_query_length_(max_query_length),
        host_keys_buffer_(host_keys_buffer),
        host_values_buffer_(host_values_buffer),
        host_num_buffer_(host_num_buffer) {}
  ~IteratorImpl() override = default;
  // Fetches up to n_request entries from the underlying table into the host
  // buffers, then copies them to the device buffers keys/values; the batch size
  // is written to n_result (a device pointer).
  void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
             void* values) override {
    CHECK_LE(n_request, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    // Drain in-flight work so the host buffers are safe to overwrite.
    CHECK_JUST(cuda_stream->Sync());
    base_iter_->Next(n_request, host_num_buffer_, host_keys_buffer_, host_values_buffer_);
    OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                                 cuda_stream->cuda_stream()));
    // Reading *host_num_buffer_ right after issuing the async copy is safe: the
    // value was produced by the host-side Next() call above, not by the copy.
    const uint32_t num_keys = *host_num_buffer_;
    if (num_keys != 0) {
      OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_,
                                   hipMemcpyDefault, cuda_stream->cuda_stream()));
      OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
                                   hipMemcpyDefault, cuda_stream->cuda_stream()));
    }
  }
  // Restarts iteration from the beginning of the table.
  void Reset() override { base_iter_->Reset(); }

 private:
  PersistentTable::Iterator* base_iter_;
  uint32_t key_size_;
  uint32_t value_size_;
  uint32_t max_query_length_;
  void* host_keys_buffer_;
  void* host_values_buffer_;
  uint32_t* host_num_buffer_;
};
// GPU-facing KeyValueStore backed by a host-side PersistentTable. Device
// queries are staged through pinned host buffers sized by
// ReserveQueryLength(); Get()/Put() are serialized with mutex_.
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
  explicit KeyValueStoreImpl(const PersistentTableKeyValueStoreOptions& options)
      : device_index_(-1), max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    key_size_ = options.table_options.key_size;
    value_size_ = options.table_options.value_size;
    table_ = NewPersistentTable(options.table_options);
    // The per-query staging buffers (host_query_keys_, host_query_values_,
    // host_missing_indices_) are allocated lazily by ReserveQueryLength().
    // Allocating them here, while max_query_length_ is still 0, would create
    // zero-sized pinned allocations that leak: ReserveQueryLength() and the
    // destructor only free the buffers when max_query_length_ != 0.
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_, reinterpret_cast<void**>(&host_n_missing_),
                                          sizeof(uint32_t)));
  }
  ~KeyValueStoreImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    // Query buffers exist only after a successful ReserveQueryLength().
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipHostFree(host_query_keys_));
      OF_CUDA_CHECK(hipHostFree(host_query_values_));
      OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
    }
    OF_CUDA_CHECK(hipHostFree(host_n_missing_));
  }
  // Key/value entry widths in bytes, as configured by the table options.
  uint32_t KeySize() const override { return key_size_; }
  uint32_t ValueSize() const override { return value_size_; }
  // Largest num_keys currently accepted by Get()/Put().
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  // Grows the pinned staging buffers to hold query_length entries; no-op when
  // the current capacity already suffices. Must not race with Get()/Put().
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipHostFree(host_query_keys_));
      OF_CUDA_CHECK(hipHostFree(host_query_values_));
      OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
    }
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(
        device_index_, reinterpret_cast<void**>(&host_query_keys_), key_size_ * query_length));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(
        device_index_, reinterpret_cast<void**>(&host_query_values_), value_size_ * query_length));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_missing_indices_),
                                          sizeof(uint32_t) * query_length));
    max_query_length_ = query_length;
  }
  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
           uint32_t* n_missing, uint32_t* missing_indices) override;
  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
  bool SnapshotExists(const std::string& name) override;
  void LoadSnapshot(const std::string& name) override;
  void LoadSnapshot(const std::string& name,
                    const std::function<void(KVIterator* iter)>& Hook) override;
  void SaveSnapshot(const std::string& name) override;

 private:
  int device_index_;
  uint32_t max_query_length_;
  uint32_t key_size_;
  uint32_t value_size_;
  // Pinned staging buffers; null until ReserveQueryLength() allocates them.
  Key* host_query_keys_{};
  uint8_t* host_query_values_{};
  uint32_t* host_n_missing_{};
  uint32_t* host_missing_indices_{};
  std::mutex mutex_;
  std::unique_ptr<PersistentTable> table_;
};
// Looks up `num_keys` device-resident keys in the host table. Found values
// are written to `values`; indices of keys absent from the table go to
// `missing_indices` and their count to `n_missing` (all device memory; the
// writes are enqueued on `stream`).
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 void* values, uint32_t* n_missing, uint32_t* missing_indices) {
  // Serialize access to the shared pinned staging buffers.
  std::lock_guard<std::mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) {
    // Nothing to look up; still report zero misses to the device.
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    return;
  }
  // Stage the device keys into pinned host memory, then block until the copy
  // (and any earlier work on the stream) completes before the host lookup.
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  // Synchronous host-side lookup into the pinned result buffers.
  table_->Get(num_keys, host_query_keys_, host_query_values_, host_n_missing_,
              host_missing_indices_);
  OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_,
                               hipMemcpyDefault, cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  // Reading *host_n_missing_ here is safe: it was written by the synchronous
  // table_->Get() above, not by the async copy.
  OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
                               (*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
}
// Inserts num_keys key/value pairs (both in device memory) into the host
// table. Blocking: returns only after the host-side insertion has completed.
template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 const void* values) {
  // Serialize access to the shared pinned staging buffers.
  std::lock_guard<std::mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) { return; }
  // Stage the device keys and values into pinned host memory.
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys,
                               hipMemcpyDefault, cuda_stream->cuda_stream()));
  // Both copies must land before the host reads the staging buffers.
  CHECK_JUST(cuda_stream->Sync());
  table_->Put(num_keys, host_query_keys_, host_query_values_);
}
// Returns true when a snapshot named `name` exists in the backing table.
template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
  const bool found = table_->SnapshotExists(name);
  return found;
}
// Restores table content from the named snapshot without a per-chunk hook.
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard device_guard(device_index_);
  LoadSnapshot(name, /*Hook=*/nullptr);
}
// Restores table content from the named snapshot. When `Hook` is non-null it
// is invoked once per snapshot chunk with a KVIterator over that chunk.
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
                                          const std::function<void(KVIterator* iter)>& Hook) {
  CudaCurrentDeviceGuard device_guard(device_index_);
  if (!Hook) {
    table_->LoadSnapshot(name);
    return;
  }
  table_->LoadSnapshot(name, [&](PersistentTable::Iterator* chunk_iterator) {
    // Expose the chunk through the KVIterator interface, staging batches in
    // this store's pinned buffers.
    IteratorImpl chunk_view(chunk_iterator, KeySize(), ValueSize(), max_query_length_,
                            host_query_keys_, host_query_values_, host_n_missing_);
    Hook(&chunk_view);
  });
}
// Persists the current table content under the snapshot name `name`.
template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard device_guard(device_index_);
  table_->SaveSnapshot(name);
}
} // namespace
// Factory: instantiates the store with the key width matching
// options.table_options.key_size (8 or 4 bytes); other widths are
// unsupported.
std::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(
    const PersistentTableKeyValueStoreOptions& options) {
  const uint32_t key_size = options.table_options.key_size;
  switch (key_size) {
    case sizeof(uint64_t):
      return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
    case sizeof(uint32_t):
      return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
    default:
      UNIMPLEMENTED();
      return nullptr;
  }
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file
......@@ -141,15 +141,15 @@ class PosixFile final {
class PosixMappedFile final {
public:
PosixMappedFile() : file_(), ptr_(nullptr) {}
// PosixMappedFile(PosixFile&& file, size_t size, int prot) : file_(std::move(file)), ptr_(nullptr) {
PosixMappedFile(PosixFile&& file, size_t size, int prot, int flags) : file_(std::move(file)), ptr_(nullptr) {
PosixMappedFile(PosixFile&& file, size_t size, int prot, int flags)
: file_(std::move(file)), ptr_(nullptr) {
CHECK_NE(file_.fd(), -1);
// void* ptr = mmap(nullptr, size, prot, MAP_SHARED, file_.fd(), 0);
void* ptr = mmap(nullptr, size, prot, flags, file_.fd(), 0);
PCHECK(ptr != MAP_FAILED);
ptr_ = ptr;
}
PosixMappedFile(PosixFile&& file, size_t size, int prot) : PosixMappedFile(std::move(file), size, prot, MAP_SHARED) {}
PosixMappedFile(PosixFile&& file, size_t size, int prot)
: PosixMappedFile(std::move(file), size, prot, MAP_SHARED) {}
PosixMappedFile(PosixMappedFile&& other) noexcept : PosixMappedFile() {
*this = std::move(other);
}
......
......@@ -93,7 +93,6 @@ REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BatchMatmulFactory,
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BatchMatmulFactory,
BatchMatmulFactoryImpl<DeviceType::kCUDA>);
#endif // WITH_CUDA
#ifdef WITH_ROCM
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BatchMatmulFactory,
BatchMatmulFactoryImpl<DeviceType::kCUDA>);
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/ep/include/primitive/binary_op.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/scalar.h"
#include <cmath>
namespace oneflow {
......@@ -124,6 +125,39 @@ struct BinaryFunctor<device, BinaryOp::kGreaterEqual, Src, Dst> {
OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 >= src1); }
};
// isclose with equal_nan=true: elements are close when exactly equal, both
// NaN, or |src0 - src1| <= atol + rtol * |src1|.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kIsCloseEqualNan, Src, Dst> {
  // attr0 = atol (absolute tolerance), attr1 = rtol (relative tolerance).
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)
      : atol(attr0.Value<float>()), rtol(attr1.Value<float>()) {}
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {
    bool close = src0 == src1;
    // NaN != NaN, so two NaNs are counted as close explicitly.
    close |= (std::isnan(src0) and std::isnan(src1));
    // Zero tolerances degenerate to exact (or NaN) equality.
    if (atol == 0 and rtol == 0) return close;
    // NOTE(review): unqualified abs() — for floating Src on the host this can
    // resolve to the integer ::abs overload; confirm the intended overload.
    Src allowed_error = static_cast<Src>(atol) + abs(static_cast<Src>(rtol) * src1);
    Src actual_error = abs(src0 - src1);
    // isfinite() guard: inf - inf yields NaN, which must not count as close.
    close |= (std::isfinite(actual_error) and (actual_error <= allowed_error));
    return close;
  }
  float atol, rtol;
};
// isclose with equal_nan=false: close when exactly equal or
// |src0 - src1| <= atol + rtol * |src1|; any NaN operand is never close.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kIsClose, Src, Dst> {
  // attr0 = atol (absolute tolerance), attr1 = rtol (relative tolerance).
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)
      : atol(attr0.Value<float>()), rtol(attr1.Value<float>()) {}
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {
    bool close = src0 == src1;
    // Zero tolerances degenerate to exact equality.
    if (atol == 0 and rtol == 0) return close;
    // NOTE(review): unqualified abs() — for floating Src on the host this can
    // resolve to the integer ::abs overload; confirm the intended overload.
    Src allowed_error = static_cast<Src>(atol) + abs(static_cast<Src>(rtol) * src1);
    Src actual_error = abs(src0 - src1);
    // isfinite() guard: NaN error (e.g. from inf - inf) never counts as close.
    close |= (std::isfinite(actual_error) and (actual_error <= allowed_error));
    return close;
  }
  float atol, rtol;
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLogicalAnd, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
......@@ -147,6 +181,81 @@ struct BinaryFunctor<device, BinaryOp::kLogicalXor, Src, Dst> {
}
};
// Remainder via the % operator (truncated-division semantics).
// NOTE(review): this generic form only compiles for integral Src;
// floating-point types presumably use a specialization elsewhere — confirm.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kFmod, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 % src1); }
};
// Division for the floor-div op. NOTE(review): this generic form is plain `/`
// (truncating toward zero for integers, not flooring toward -inf);
// negative-operand flooring is presumably handled elsewhere — confirm.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kFloorDiv, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return src0 / src1; }
};
// Division truncated toward zero — native `/` semantics for the element type.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kTruncDiv, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 / src1); }
};
// Floored modulo (Python-style): the result carries the sign of the divisor,
// obtained by adjusting the truncated remainder when the signs disagree.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kFloorMod, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {
    Src trunc_mod = src0 % src1;
    // A nonzero remainder whose sign differs from the divisor's needs one
    // divisor added to land in the divisor's sign range.
    return (trunc_mod != static_cast<Src>(0))
                   && ((src1 < static_cast<Src>(0)) != (trunc_mod < static_cast<Src>(0)))
               ? trunc_mod + src1
               : trunc_mod;
  }
};
// Unsigned kFloorMod specializations: with no negative operands, floored and
// truncated modulo coincide, so plain % suffices (the generic template's
// `< 0` sign tests are vacuous for unsigned types and would trigger
// compiler warnings).
template<DeviceType device>
struct BinaryFunctor<device, BinaryOp::kFloorMod, uint8_t, uint8_t> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC uint8_t operator()(uint8_t src0, uint8_t src1) const { return src0 % src1; }
};

template<DeviceType device>
struct BinaryFunctor<device, BinaryOp::kFloorMod, uint32_t, uint32_t> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC uint32_t operator()(uint32_t src0, uint32_t src1) const { return src0 % src1; }
};

template<DeviceType device>
struct BinaryFunctor<device, BinaryOp::kFloorMod, uint64_t, uint64_t> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC uint64_t operator()(uint64_t src0, uint64_t src1) const { return src0 % src1; }
};
// Gradient of x^c for a fixed exponent c (attr0): dy * c * x^(c-1).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kScalarBasePowerGrad, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<Src>()) {}
  // src0 = forward input x, src1 = incoming gradient dy.
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {
    return scalar_operand * (pow(src0, scalar_operand - static_cast<Src>(1))) * src1;
  }
  Src scalar_operand;
};
// Gradient of c^x for a fixed base c (attr0): dy * c^x * ln(c).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kScalarExpPowerGrad, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : scalar_operand(attr0.Value<Src>()) {}
  // src0 = forward input x, src1 = incoming gradient dy.
  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {
    return (pow(scalar_operand, src0)) * log(scalar_operand) * src1;
  }
  Src scalar_operand;
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kEluBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<double>()) {}
......@@ -314,6 +423,226 @@ struct BinaryFunctor<device, BinaryOp::kThresholdBackwardWithDyX, Src, Dst> {
const Src threshold;
};
// Gradient of |x|: sign(x) * dy, with the subgradient at x == 0 taken as 0.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kAbsBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src zero = static_cast<Src>(0.0);
    if (x == zero) {
      return zero;
    } else if (x < zero) {
      return -dy;
    } else {
      return dy;
    }
  }
};
// Inverse trig/hyperbolic gradients, each expressed with rsqrt:
//   d/dx acos(x)  = -1/sqrt(1 - x^2)
//   d/dx acosh(x) =  1/sqrt(x^2 - 1)
//   d/dx asin(x)  =  1/sqrt(1 - x^2)
//   d/dx asinh(x) =  1/sqrt(1 + x^2)
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kAcosBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * -rsqrt(static_cast<Src>(1.0) - x * x);
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kAcoshBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * rsqrt(x * x - static_cast<Src>(1.0));
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kAsinBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * rsqrt(static_cast<Src>(1.0) - x * x);
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kAsinhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * rsqrt(static_cast<Src>(1.0) + x * x);
  }
};
// d/dx atan(x) = 1/(1 + x^2); d/dx atanh(x) = 1/(1 - x^2).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kAtanBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    return dy * (one / (one + x * x));
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kAtanhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    return dy * (one / (one - x * x));
  }
};
// d/dx cos(x) = -sin(x); d/dx cosh(x) = sinh(x).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kCosBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (-sin(x)); }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kCoshBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * sinh(x); }
};
// d/dx erf(x) = 2/sqrt(pi) * exp(-x^2); erfc(x) = 1 - erf(x), so its
// gradient is the negation of erf's.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kErfBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * static_cast<Src>(2.0) * rsqrt(static_cast<Src>(M_PI)) * exp(-x * x);
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kErfcBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * -static_cast<Src>(2.0) * rsqrt(static_cast<Src>(M_PI)) * exp(-x * x);
  }
};
// d/dx exp(x) = exp(x); expm1(x) = exp(x) - 1 has the same derivative.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kExpBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * exp(x); }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kExpm1BackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * exp(x); }
};
// Gradient of lgamma — not yet implemented (needs digamma); traps at runtime
// if ever instantiated and called.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLgammaBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    // TODO(chengcheng): return: dy * digamma(x)
    assert(false);
    return 0.0;
  }
};
// Logarithm gradients: d/dx log_b(x) = 1/(x * ln(b)) for natural log, log2
// and log10; d/dx log1p(x) = 1/(x + 1).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLogBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * (static_cast<Src>(1.0) / x); }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLog2BackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * (static_cast<Src>(1.0) / (x * log(static_cast<Src>(2.0))));
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLog10BackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * (static_cast<Src>(1.0) / (x * log(static_cast<Src>(10.0))));
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLog1pBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * (static_cast<Src>(1.0) / (x + static_cast<Src>(1.0)));
  }
};
// d/dx log(sigmoid(x)) = 1/(exp(x) + 1), i.e. sigmoid(-x).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLogSigmoidBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * (static_cast<Src>(1.0) / (exp(x) + static_cast<Src>(1.0)));
  }
};
// d/dx (1/x) = -1/x^2. The NoNan variant returns 0 at x == 0 instead of
// inf/NaN — presumably matching a forward reciprocal_no_nan that defines
// 1/0 := 0 (confirm against the forward op).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kReciprocalBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * (-static_cast<Src>(1.0) / (x * x));
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kReciprocalNoNanBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    // Zero input would otherwise produce an inf/NaN gradient.
    if (abs(x) <= static_cast<Src>(0.0)) { return static_cast<Dst>(0.0); }
    return dy * (-static_cast<Src>(1.0) / (x * x));
  }
};
// d/dx x^(-1/2) = -1/(2 * x^(3/2)), written here as -1/(2 * sqrt(x^3)).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kRsqrtBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * (static_cast<Src>(-1.0) / (static_cast<Src>(2.0) * sqrt(x * x * x)));
  }
};
// Gradient of sigmoid expressed in terms of the forward output y:
// d/dx sigmoid = y * (1 - y), scaled by the incoming gradient dy.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSigmoidBackwardWithDyY, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {
    // static_cast keeps the arithmetic in Src: the previous bare 1.0 literal
    // silently promoted float math to double on the device (and does not
    // compile for half), unlike every sibling functor in this file.
    return dy * (y * (static_cast<Src>(1.0) - y));
  }
};
// d/dx sin(x) = cos(x); d/dx sinh(x) = cosh(x).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSinBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * cos(x); }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSinhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * cosh(x); }
};
// d/dx sqrt(x) = 1/(2 * sqrt(x)); d/dx x^2 = 2x.
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSqrtBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return dy * static_cast<Src>(0.5) / sqrt(x);
  }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSquareBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const { return dy * static_cast<Src>(2.0) * x; }
};
// d/dx tan(x) = 1/cos^2(x) (= sec^2(x)).
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kTanBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src cos_val = cos(x);
    return dy * (static_cast<Src>(1.0) / (cos_val * cos_val));
  }
};
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment