Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
a715222c
Commit
a715222c
authored
Feb 28, 2023
by
yuguo
Browse files
0.9.1-rocm
parent
f262efc9
Changes
469
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1857 additions
and
2052 deletions
+1857
-2052
oneflow/core/embedding/embedding_manager.cpp
oneflow/core/embedding/embedding_manager.cpp
+705
-2
oneflow/core/embedding/embedding_manager.h
oneflow/core/embedding/embedding_manager.h
+137
-23
oneflow/core/embedding/full_cache.cu
oneflow/core/embedding/full_cache.cu
+170
-67
oneflow/core/embedding/full_cache.h
oneflow/core/embedding/full_cache.h
+1
-6
oneflow/core/embedding/full_cache.hip.cpp
oneflow/core/embedding/full_cache.hip.cpp
+0
-640
oneflow/core/embedding/hash_functions.hip.h
oneflow/core/embedding/hash_functions.hip.h
+0
-100
oneflow/core/embedding/key_value_store_test.cpp
oneflow/core/embedding/key_value_store_test.cpp
+236
-0
oneflow/core/embedding/lru_cache.cu
oneflow/core/embedding/lru_cache.cu
+211
-56
oneflow/core/embedding/lru_cache.hip.cpp
oneflow/core/embedding/lru_cache.hip.cpp
+0
-585
oneflow/core/embedding/mock_key_value_store.cu
oneflow/core/embedding/mock_key_value_store.cu
+24
-23
oneflow/core/embedding/mock_key_value_store.h
oneflow/core/embedding/mock_key_value_store.h
+1
-11
oneflow/core/embedding/mock_key_value_store.hip.cpp
oneflow/core/embedding/mock_key_value_store.hip.cpp
+0
-249
oneflow/core/embedding/persistent_table.cpp
oneflow/core/embedding/persistent_table.cpp
+13
-8
oneflow/core/embedding/persistent_table.h
oneflow/core/embedding/persistent_table.h
+1
-0
oneflow/core/embedding/persistent_table_key_value_store.cu
oneflow/core/embedding/persistent_table_key_value_store.cu
+24
-23
oneflow/core/embedding/persistent_table_key_value_store.h
oneflow/core/embedding/persistent_table_key_value_store.h
+1
-11
oneflow/core/embedding/persistent_table_key_value_store.hip.cpp
...w/core/embedding/persistent_table_key_value_store.hip.cpp
+0
-243
oneflow/core/embedding/posix_file.h
oneflow/core/embedding/posix_file.h
+4
-4
oneflow/core/ep/common/primitive/batch_matmul.cpp
oneflow/core/ep/common/primitive/batch_matmul.cpp
+0
-1
oneflow/core/ep/common/primitive/binary_functor.h
oneflow/core/ep/common/primitive/binary_functor.h
+329
-0
No files found.
Too many changes to show.
To preserve performance only
469 of 469+
files are displayed.
Plain diff
Email patch
oneflow/core/embedding/embedding_manager.cpp
View file @
a715222c
...
...
@@ -24,7 +24,470 @@ namespace embedding {
#ifdef WITH_CUDA
constexpr
size_t
kDefaultMaxQueryLength
=
65536
;
constexpr
size_t
kDefaultMaxQueryLength
=
131072
;
constexpr
int64_t
kRingBufferSize
=
8
;
struct
IdStatistics
{
IdStatistics
()
:
final_num_unique
(
0
),
iter
(
-
1
)
{}
uint32_t
final_num_unique
;
std
::
vector
<
uint32_t
>
num_unique_matrix
;
int64_t
iter
;
};
#if CUDA_VERSION >= 11020
class
DynamicTmpBufferAllocator
final
:
public
TmpBufferAllocator
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
DynamicTmpBufferAllocator
);
DynamicTmpBufferAllocator
(
cudaStream_t
stream
,
cudaMemPool_t
pool
)
:
stream_
(
stream
),
mem_pool_
(
pool
)
{}
~
DynamicTmpBufferAllocator
()
override
=
default
;
void
Allocate
(
void
**
ptr
,
size_t
size
)
override
{
OF_CUDA_CHECK
(
cudaMallocFromPoolAsync
(
ptr
,
GetCudaAlignedSize
(
size
),
mem_pool_
,
stream_
));
}
void
Free
(
void
*
ptr
)
override
{
OF_CUDA_CHECK
(
cudaFreeAsync
(
ptr
,
stream_
));
}
private:
cudaStream_t
stream_
{};
cudaMemPool_t
mem_pool_
{};
};
class
DynamicAllocationEmbeddingState
final
:
public
EmbeddingState
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
DynamicAllocationEmbeddingState
);
DynamicAllocationEmbeddingState
()
:
lookup_values_
(
nullptr
),
lookup_values_size_
(
0
),
has_lookup_values_
(
false
),
lookup_embeddings_
(
nullptr
),
lookup_embeddings_size_
(
0
),
has_lookup_embeddings_
(
false
),
updated_values_
(
nullptr
),
iter_
(
-
1
)
{
OF_CUDA_CHECK
(
cudaGetDevice
(
&
device_index_
));
id_statistics_vec_
.
resize
(
kRingBufferSize
);
cudaMemPoolProps
poolProps
=
{};
poolProps
.
allocType
=
cudaMemAllocationTypePinned
;
poolProps
.
handleTypes
=
cudaMemHandleTypePosixFileDescriptor
;
poolProps
.
location
.
type
=
cudaMemLocationTypeDevice
;
poolProps
.
location
.
id
=
device_index_
;
cudaMemPoolCreate
(
&
mem_pool_
,
&
poolProps
);
uint64_t
threshold
=
UINT64_MAX
;
cudaMemPoolSetAttribute
(
mem_pool_
,
cudaMemPoolAttrReleaseThreshold
,
&
threshold
);
}
~
DynamicAllocationEmbeddingState
()
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
has_lookup_values_
)
{
OF_CUDA_CHECK
(
cudaFree
(
lookup_values_
));
}
if
(
has_lookup_embeddings_
)
{
OF_CUDA_CHECK
(
cudaFree
(
lookup_embeddings_
));
}
OF_CUDA_CHECK
(
cudaMemPoolDestroy
(
mem_pool_
));
}
std
::
unique_ptr
<
TmpBufferAllocator
>
NewTmpBufferAllocator
(
user_op
::
KernelComputeContext
*
ctx
)
override
{
return
std
::
make_unique
<
DynamicTmpBufferAllocator
>
(
ctx
->
stream
()
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
(),
mem_pool_
);
}
void
OnEmbeddingLookupStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
iter_
=
iter
;
cudaStream_t
cuda_stream
=
ctx
->
stream
()
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
();
user_op
::
Tensor
*
unique_values
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_values"
,
0
);
const
int64_t
embedding_size
=
ctx
->
Attr
<
int64_t
>
(
"embedding_size"
);
const
int64_t
line_size
=
ctx
->
Attr
<
int64_t
>
(
"line_size"
);
uint32_t
num_unique
=
this
->
GetIdNumUnique
(
iter
);
size_t
lookup_values_size
=
GetCudaAlignedSize
(
num_unique
*
line_size
*
GetSizeOfDataType
(
unique_values
->
data_type
()));
if
(
!
has_lookup_values_
||
lookup_values_size_
<
lookup_values_size
)
{
if
(
has_lookup_values_
)
{
OF_CUDA_CHECK
(
cudaFreeAsync
(
lookup_values_
,
cuda_stream
));
}
OF_CUDA_CHECK
(
cudaMallocFromPoolAsync
(
&
lookup_values_
,
lookup_values_size
,
mem_pool_
,
cuda_stream
));
has_lookup_values_
=
true
;
lookup_values_size_
=
lookup_values_size
;
if
(
ctx
->
has_output
(
"embeddings"
,
0
))
{
user_op
::
Tensor
*
embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"embeddings"
,
0
);
const
size_t
lookup_embeddings_size
=
GetCudaAlignedSize
(
num_unique
*
embedding_size
*
GetSizeOfDataType
(
embeddings
->
data_type
()));
if
(
!
has_lookup_embeddings_
||
lookup_embeddings_size_
<
lookup_values_size
)
{
if
(
has_lookup_embeddings_
)
{
OF_CUDA_CHECK
(
cudaFreeAsync
(
lookup_embeddings_
,
cuda_stream
));
}
OF_CUDA_CHECK
(
cudaMallocFromPoolAsync
(
&
lookup_embeddings_
,
lookup_embeddings_size
,
mem_pool_
,
cuda_stream
));
has_lookup_embeddings_
=
true
;
lookup_embeddings_size_
=
lookup_embeddings_size
;
}
}
else
{
lookup_embeddings_
=
nullptr
;
}
}
}
void
*
LookupUniqueValues
(
int64_t
iter
)
override
{
CHECK_EQ
(
iter_
,
iter
);
CHECK
(
has_lookup_values_
);
return
lookup_values_
;
}
void
*
LookupEmbeddings
(
int64_t
iter
)
override
{
CHECK_EQ
(
iter_
,
iter
);
CHECK
(
has_lookup_embeddings_
);
return
lookup_embeddings_
;
}
void
OnEmbeddingLookupEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
void
OnEmbeddingGatherStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
const
void
*
EmbeddingGatherIn
(
int64_t
iter
)
override
{
if
(
has_lookup_embeddings_
)
{
return
lookup_embeddings_
;
}
else
{
CHECK
(
has_lookup_values_
);
return
lookup_values_
;
}
}
void
OnEmbeddingGatherEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
void
OnEmbeddingShuffleStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
const
void
*
EmbeddingShuffleCurRankEmbeddings
(
int64_t
iter
)
override
{
if
(
has_lookup_embeddings_
)
{
return
lookup_embeddings_
;
}
else
{
CHECK
(
has_lookup_values_
);
return
lookup_values_
;
}
}
void
OnEmbeddingShuffleEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
void
OnEmbeddingUpdateStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
updated_unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"updated_unique_embeddings"
,
0
);
const
int64_t
line_size
=
ctx
->
Attr
<
int64_t
>
(
"line_size"
);
uint32_t
num_unique
=
this
->
GetIdNumUnique
(
iter
);
size_t
update_values_size
=
GetCudaAlignedSize
(
num_unique
*
line_size
*
GetSizeOfDataType
(
updated_unique_embeddings
->
data_type
()));
OF_CUDA_CHECK
(
cudaMallocFromPoolAsync
(
&
updated_values_
,
update_values_size
,
mem_pool_
,
ctx
->
stream
()
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
}
const
void
*
EmbeddingUpdateUniqueEmbeddings
(
int64_t
iter
)
override
{
CHECK_EQ
(
iter_
,
iter
);
CHECK
(
has_lookup_values_
);
return
lookup_values_
;
}
void
*
EmbeddingUpdateUpdatedUniqueEmbeddings
(
int64_t
iter
)
override
{
CHECK_EQ
(
iter_
,
iter
);
return
updated_values_
;
}
void
OnEmbeddingUpdateEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
void
OnEmbeddingPutStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
const
void
*
EmbeddingPutUniqueEmbeddings
(
int64_t
iter
)
override
{
CHECK_EQ
(
iter_
,
iter
);
return
updated_values_
;
}
void
OnEmbeddingPutEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
OF_CUDA_CHECK
(
cudaFreeAsync
(
updated_values_
,
ctx
->
stream
()
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
}
void
OnEmbeddingFusedUpdatePutStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
const
void
*
EmbeddingFusedUpdatePutUniqueEmbeddings
(
int64_t
iter
)
override
{
CHECK_EQ
(
iter_
,
iter
);
CHECK
(
has_lookup_values_
);
return
lookup_values_
;
}
void
OnEmbeddingFusedUpdatePutEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
// do nothing
}
void
SetIdFinalNumUnique
(
uint32_t
final_num_unique
,
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
id_statistics_vec_
.
at
(
index
).
final_num_unique
=
final_num_unique
;
id_statistics_vec_
.
at
(
index
).
iter
=
iter
;
}
void
SetIdNumUniqueMatrix
(
const
std
::
vector
<
uint32_t
>&
num_unique_matrix
,
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
id_statistics_vec_
.
at
(
index
).
num_unique_matrix
=
num_unique_matrix
;
id_statistics_vec_
.
at
(
index
).
iter
=
iter
;
}
uint32_t
GetIdNumUnique
(
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
const
IdStatistics
&
statistics
=
id_statistics_vec_
.
at
(
index
);
CHECK_EQ
(
statistics
.
iter
,
iter
)
<<
"saved iter: "
<<
statistics
.
iter
<<
" current iter: "
<<
iter
;
return
statistics
.
final_num_unique
;
}
const
std
::
vector
<
uint32_t
>&
GetIdNumUniqueMatrix
(
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
const
IdStatistics
&
statistics
=
id_statistics_vec_
.
at
(
index
);
CHECK_EQ
(
statistics
.
iter
,
iter
)
<<
"saved iter: "
<<
statistics
.
iter
<<
" current iter: "
<<
iter
;
return
statistics
.
num_unique_matrix
;
}
private:
void
*
lookup_values_
;
size_t
lookup_values_size_
;
bool
has_lookup_values_
;
void
*
lookup_embeddings_
;
size_t
lookup_embeddings_size_
;
bool
has_lookup_embeddings_
;
void
*
updated_values_
;
int64_t
iter_
;
std
::
vector
<
IdStatistics
>
id_statistics_vec_
;
int
device_index_
{};
cudaMemPool_t
mem_pool_
{};
std
::
mutex
mutex_
;
};
#endif
class
StaticTmpBufferAllocator
final
:
public
TmpBufferAllocator
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
StaticTmpBufferAllocator
);
StaticTmpBufferAllocator
(
void
*
ptr
,
size_t
size
)
:
ptr_
(
ptr
),
offset_
(
0
),
size_
(
size
)
{}
~
StaticTmpBufferAllocator
()
override
=
default
;
void
Allocate
(
void
**
ptr
,
size_t
size
)
override
{
CHECK
(
ptr_
!=
nullptr
);
CHECK_GE
(
offset_
,
0
);
size_t
aligned_size
=
GetCudaAlignedSize
(
size
);
CHECK_LE
(
offset_
+
aligned_size
,
size_
);
*
ptr
=
reinterpret_cast
<
char
*>
(
ptr_
)
+
offset_
;
offset_
+=
aligned_size
;
}
void
Free
(
void
*
ptr
)
override
{
// do nothing
}
private:
void
*
ptr_
;
int64_t
offset_
;
size_t
size_
;
};
class
StaticAllocationEmbeddingState
final
:
public
EmbeddingState
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
StaticAllocationEmbeddingState
);
StaticAllocationEmbeddingState
()
:
lookup_unique_values_
(
nullptr
),
lookup_embeddings_
(
nullptr
),
has_lookup_embeddings_
(
false
),
embedding_shuffle_cur_rank_embeddings_
(
nullptr
),
embedding_update_unique_embeddings_
(
nullptr
),
embedding_update_updated_unique_embeddings_
(
nullptr
),
embedding_put_unique_embeddings_
(
nullptr
),
embedding_fused_update_put_unique_embeddings_
(
nullptr
)
{
id_statistics_vec_
.
resize
(
kRingBufferSize
);
}
~
StaticAllocationEmbeddingState
()
override
=
default
;
std
::
unique_ptr
<
TmpBufferAllocator
>
NewTmpBufferAllocator
(
user_op
::
KernelComputeContext
*
ctx
)
override
{
user_op
::
Tensor
*
tmp_buffer
=
ctx
->
Tensor4ArgNameAndIndex
(
"tmp_buffer"
,
0
);
return
std
::
make_unique
<
StaticTmpBufferAllocator
>
(
tmp_buffer
->
mut_dptr
(),
tmp_buffer
->
shape_view
().
elem_cnt
());
}
void
OnEmbeddingLookupStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
user_op
::
Tensor
*
unique_values
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_values"
,
0
);
lookup_unique_values_
=
unique_values
->
mut_dptr
();
if
(
ctx
->
has_output
(
"embeddings"
,
0
))
{
user_op
::
Tensor
*
embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"embeddings"
,
0
);
has_lookup_embeddings_
=
true
;
lookup_embeddings_
=
embeddings
->
mut_dptr
();
}
}
void
*
LookupUniqueValues
(
int64_t
iter
)
override
{
return
lookup_unique_values_
;
}
void
*
LookupEmbeddings
(
int64_t
iter
)
override
{
CHECK
(
has_lookup_embeddings_
);
return
lookup_embeddings_
;
}
void
OnEmbeddingLookupEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
lookup_unique_values_
=
nullptr
;
lookup_embeddings_
=
nullptr
;
has_lookup_embeddings_
=
false
;
}
void
OnEmbeddingGatherStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
in
=
ctx
->
Tensor4ArgNameAndIndex
(
"in"
,
0
);
embedding_gather_in_
=
in
->
dptr
();
}
const
void
*
EmbeddingGatherIn
(
int64_t
iter
)
override
{
return
embedding_gather_in_
;
}
void
OnEmbeddingGatherEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_gather_in_
=
nullptr
;
}
void
OnEmbeddingShuffleStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
cur_rank_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"cur_rank_embeddings"
,
0
);
embedding_shuffle_cur_rank_embeddings_
=
cur_rank_embeddings
->
dptr
();
}
const
void
*
EmbeddingShuffleCurRankEmbeddings
(
int64_t
iter
)
override
{
return
embedding_shuffle_cur_rank_embeddings_
;
}
void
OnEmbeddingShuffleEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_shuffle_cur_rank_embeddings_
=
nullptr
;
}
void
OnEmbeddingUpdateStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_embeddings"
,
0
);
user_op
::
Tensor
*
updated_unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"updated_unique_embeddings"
,
0
);
embedding_update_unique_embeddings_
=
unique_embeddings
->
dptr
();
embedding_update_updated_unique_embeddings_
=
updated_unique_embeddings
->
mut_dptr
();
}
const
void
*
EmbeddingUpdateUniqueEmbeddings
(
int64_t
iter
)
override
{
return
embedding_update_unique_embeddings_
;
}
void
*
EmbeddingUpdateUpdatedUniqueEmbeddings
(
int64_t
iter
)
override
{
return
embedding_update_updated_unique_embeddings_
;
}
void
OnEmbeddingUpdateEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_update_unique_embeddings_
=
nullptr
;
embedding_update_updated_unique_embeddings_
=
nullptr
;
}
void
OnEmbeddingPutStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_embeddings"
,
0
);
embedding_put_unique_embeddings_
=
unique_embeddings
->
dptr
();
}
const
void
*
EmbeddingPutUniqueEmbeddings
(
int64_t
iter
)
override
{
return
embedding_put_unique_embeddings_
;
}
void
OnEmbeddingPutEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_put_unique_embeddings_
=
nullptr
;
}
void
OnEmbeddingFusedUpdatePutStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_embeddings"
,
0
);
embedding_fused_update_put_unique_embeddings_
=
unique_embeddings
->
dptr
();
}
const
void
*
EmbeddingFusedUpdatePutUniqueEmbeddings
(
int64_t
iter
)
override
{
return
embedding_fused_update_put_unique_embeddings_
;
}
void
OnEmbeddingFusedUpdatePutEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_fused_update_put_unique_embeddings_
=
nullptr
;
}
void
SetIdFinalNumUnique
(
uint32_t
final_num_unique
,
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
id_statistics_vec_
.
at
(
index
).
final_num_unique
=
final_num_unique
;
id_statistics_vec_
.
at
(
index
).
iter
=
iter
;
}
void
SetIdNumUniqueMatrix
(
const
std
::
vector
<
uint32_t
>&
num_unique_matrix
,
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
id_statistics_vec_
.
at
(
index
).
num_unique_matrix
=
num_unique_matrix
;
id_statistics_vec_
.
at
(
index
).
iter
=
iter
;
}
uint32_t
GetIdNumUnique
(
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
const
IdStatistics
&
statistics
=
id_statistics_vec_
.
at
(
index
);
CHECK_EQ
(
statistics
.
iter
,
iter
)
<<
"saved iter: "
<<
statistics
.
iter
<<
" current iter: "
<<
iter
;
return
statistics
.
final_num_unique
;
}
const
std
::
vector
<
uint32_t
>&
GetIdNumUniqueMatrix
(
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
const
IdStatistics
&
statistics
=
id_statistics_vec_
.
at
(
index
);
CHECK_EQ
(
statistics
.
iter
,
iter
)
<<
"saved iter: "
<<
statistics
.
iter
<<
" current iter: "
<<
iter
;
return
statistics
.
num_unique_matrix
;
}
void
*
lookup_unique_values_
;
void
*
lookup_embeddings_
;
bool
has_lookup_embeddings_
;
const
void
*
embedding_gather_in_
;
const
void
*
embedding_shuffle_cur_rank_embeddings_
;
const
void
*
embedding_update_unique_embeddings_
;
void
*
embedding_update_updated_unique_embeddings_
;
const
void
*
embedding_put_unique_embeddings_
;
const
void
*
embedding_fused_update_put_unique_embeddings_
;
std
::
vector
<
IdStatistics
>
id_statistics_vec_
;
std
::
mutex
mutex_
;
};
EmbeddingState
*
EmbeddingManager
::
GetEmbeddingState
(
const
std
::
string
&
embedding_name
,
int64_t
rank_id
)
{
std
::
pair
<
std
::
string
,
int64_t
>
map_key
=
std
::
make_pair
(
embedding_name
,
rank_id
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
auto
it
=
embedding_state_map_
.
find
(
map_key
);
// for id shuffle test, not need to create table
if
(
it
==
embedding_state_map_
.
end
())
{
LOG
(
INFO
)
<<
"create embedding state: "
<<
embedding_name
<<
"-"
<<
rank_id
;
if
(
UseDynamicMemoryAllocation
())
{
#if CUDA_VERSION >= 11020
it
=
embedding_state_map_
.
emplace
(
map_key
,
std
::
make_unique
<
DynamicAllocationEmbeddingState
>
())
.
first
;
#else
UNIMPLEMENTED
();
#endif
}
else
{
it
=
embedding_state_map_
.
emplace
(
map_key
,
std
::
make_unique
<
StaticAllocationEmbeddingState
>
())
.
first
;
}
}
return
it
->
second
.
get
();
}
KeyValueStore
*
EmbeddingManager
::
GetKeyValueStore
(
const
std
::
string
&
embedding_name
,
int64_t
rank_id
)
{
...
...
@@ -66,6 +529,22 @@ void EmbeddingManager::CreateKeyValueStore(const KeyValueStoreOptions& key_value
store
->
ReserveQueryLength
(
kDefaultMaxQueryLength
);
CHECK
(
key_value_store_map_
.
emplace
(
map_key
,
std
::
move
(
store
)).
second
)
<<
"Can't create an embedding with same name of an existing embedding, the name: "
<<
name
;
if
(
UseDynamicMemoryAllocation
())
{
#if CUDA_VERSION >= 11020
CHECK
(
embedding_state_map_
.
emplace
(
map_key
,
std
::
make_unique
<
DynamicAllocationEmbeddingState
>
())
.
second
)
<<
"Can't create an embedding state with same name of an existing embedding, the name: "
<<
name
;
#else
UNIMPLEMENTED
();
#endif
}
else
{
CHECK
(
embedding_state_map_
.
emplace
(
map_key
,
std
::
make_unique
<
StaticAllocationEmbeddingState
>
())
.
second
)
<<
"Can't create an embedding state with same name of an existing embedding, the name: "
<<
name
;
}
}
void
EmbeddingManager
::
SaveSnapshot
(
const
std
::
string
&
embedding_name
,
int64_t
local_rank_id
,
...
...
@@ -101,6 +580,221 @@ void EmbeddingManager::LoadSnapshot(const std::string& embedding_name, int64_t l
constexpr
size_t
kDefaultMaxQueryLength
=
131072
;
constexpr
int64_t
kRingBufferSize
=
8
;
struct
IdStatistics
{
IdStatistics
()
:
final_num_unique
(
0
),
iter
(
-
1
)
{}
uint32_t
final_num_unique
;
std
::
vector
<
uint32_t
>
num_unique_matrix
;
int64_t
iter
;
};
class
StaticTmpBufferAllocator
final
:
public
TmpBufferAllocator
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
StaticTmpBufferAllocator
);
StaticTmpBufferAllocator
(
void
*
ptr
,
size_t
size
)
:
ptr_
(
ptr
),
offset_
(
0
),
size_
(
size
)
{}
~
StaticTmpBufferAllocator
()
override
=
default
;
void
Allocate
(
void
**
ptr
,
size_t
size
)
override
{
CHECK
(
ptr_
!=
nullptr
);
CHECK_GE
(
offset_
,
0
);
size_t
aligned_size
=
GetCudaAlignedSize
(
size
);
CHECK_LE
(
offset_
+
aligned_size
,
size_
);
*
ptr
=
reinterpret_cast
<
char
*>
(
ptr_
)
+
offset_
;
offset_
+=
aligned_size
;
}
void
Free
(
void
*
ptr
)
override
{
// do nothing
}
private:
void
*
ptr_
;
int64_t
offset_
;
size_t
size_
;
};
class
StaticAllocationEmbeddingState
final
:
public
EmbeddingState
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
StaticAllocationEmbeddingState
);
StaticAllocationEmbeddingState
()
:
lookup_unique_values_
(
nullptr
),
lookup_embeddings_
(
nullptr
),
has_lookup_embeddings_
(
false
),
embedding_shuffle_cur_rank_embeddings_
(
nullptr
),
embedding_update_unique_embeddings_
(
nullptr
),
embedding_update_updated_unique_embeddings_
(
nullptr
),
embedding_put_unique_embeddings_
(
nullptr
),
embedding_fused_update_put_unique_embeddings_
(
nullptr
)
{
id_statistics_vec_
.
resize
(
kRingBufferSize
);
}
~
StaticAllocationEmbeddingState
()
override
=
default
;
std
::
unique_ptr
<
TmpBufferAllocator
>
NewTmpBufferAllocator
(
user_op
::
KernelComputeContext
*
ctx
)
override
{
user_op
::
Tensor
*
tmp_buffer
=
ctx
->
Tensor4ArgNameAndIndex
(
"tmp_buffer"
,
0
);
return
std
::
make_unique
<
StaticTmpBufferAllocator
>
(
tmp_buffer
->
mut_dptr
(),
tmp_buffer
->
shape_view
().
elem_cnt
());
}
void
OnEmbeddingLookupStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
user_op
::
Tensor
*
unique_values
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_values"
,
0
);
lookup_unique_values_
=
unique_values
->
mut_dptr
();
if
(
ctx
->
has_output
(
"embeddings"
,
0
))
{
user_op
::
Tensor
*
embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"embeddings"
,
0
);
has_lookup_embeddings_
=
true
;
lookup_embeddings_
=
embeddings
->
mut_dptr
();
}
}
void
*
LookupUniqueValues
(
int64_t
iter
)
override
{
return
lookup_unique_values_
;
}
void
*
LookupEmbeddings
(
int64_t
iter
)
override
{
CHECK
(
has_lookup_embeddings_
);
return
lookup_embeddings_
;
}
void
OnEmbeddingLookupEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
lookup_unique_values_
=
nullptr
;
lookup_embeddings_
=
nullptr
;
has_lookup_embeddings_
=
false
;
}
void
OnEmbeddingGatherStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
in
=
ctx
->
Tensor4ArgNameAndIndex
(
"in"
,
0
);
embedding_gather_in_
=
in
->
dptr
();
}
const
void
*
EmbeddingGatherIn
(
int64_t
iter
)
override
{
return
embedding_gather_in_
;
}
void
OnEmbeddingGatherEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_gather_in_
=
nullptr
;
}
void
OnEmbeddingShuffleStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
cur_rank_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"cur_rank_embeddings"
,
0
);
embedding_shuffle_cur_rank_embeddings_
=
cur_rank_embeddings
->
dptr
();
}
const
void
*
EmbeddingShuffleCurRankEmbeddings
(
int64_t
iter
)
override
{
return
embedding_shuffle_cur_rank_embeddings_
;
}
void
OnEmbeddingShuffleEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_shuffle_cur_rank_embeddings_
=
nullptr
;
}
void
OnEmbeddingUpdateStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_embeddings"
,
0
);
user_op
::
Tensor
*
updated_unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"updated_unique_embeddings"
,
0
);
embedding_update_unique_embeddings_
=
unique_embeddings
->
dptr
();
embedding_update_updated_unique_embeddings_
=
updated_unique_embeddings
->
mut_dptr
();
}
const
void
*
EmbeddingUpdateUniqueEmbeddings
(
int64_t
iter
)
override
{
return
embedding_update_unique_embeddings_
;
}
void
*
EmbeddingUpdateUpdatedUniqueEmbeddings
(
int64_t
iter
)
override
{
return
embedding_update_updated_unique_embeddings_
;
}
void
OnEmbeddingUpdateEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_update_unique_embeddings_
=
nullptr
;
embedding_update_updated_unique_embeddings_
=
nullptr
;
}
void
OnEmbeddingPutStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_embeddings"
,
0
);
embedding_put_unique_embeddings_
=
unique_embeddings
->
dptr
();
}
const
void
*
EmbeddingPutUniqueEmbeddings
(
int64_t
iter
)
override
{
return
embedding_put_unique_embeddings_
;
}
void
OnEmbeddingPutEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_put_unique_embeddings_
=
nullptr
;
}
void
OnEmbeddingFusedUpdatePutStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
const
user_op
::
Tensor
*
unique_embeddings
=
ctx
->
Tensor4ArgNameAndIndex
(
"unique_embeddings"
,
0
);
embedding_fused_update_put_unique_embeddings_
=
unique_embeddings
->
dptr
();
}
const
void
*
EmbeddingFusedUpdatePutUniqueEmbeddings
(
int64_t
iter
)
override
{
return
embedding_fused_update_put_unique_embeddings_
;
}
void
OnEmbeddingFusedUpdatePutEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
override
{
embedding_fused_update_put_unique_embeddings_
=
nullptr
;
}
void
SetIdFinalNumUnique
(
uint32_t
final_num_unique
,
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
id_statistics_vec_
.
at
(
index
).
final_num_unique
=
final_num_unique
;
id_statistics_vec_
.
at
(
index
).
iter
=
iter
;
}
void
SetIdNumUniqueMatrix
(
const
std
::
vector
<
uint32_t
>&
num_unique_matrix
,
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
id_statistics_vec_
.
at
(
index
).
num_unique_matrix
=
num_unique_matrix
;
id_statistics_vec_
.
at
(
index
).
iter
=
iter
;
}
uint32_t
GetIdNumUnique
(
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
const
IdStatistics
&
statistics
=
id_statistics_vec_
.
at
(
index
);
CHECK_EQ
(
statistics
.
iter
,
iter
)
<<
"saved iter: "
<<
statistics
.
iter
<<
" current iter: "
<<
iter
;
return
statistics
.
final_num_unique
;
}
const
std
::
vector
<
uint32_t
>&
GetIdNumUniqueMatrix
(
int64_t
iter
)
override
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
index
=
iter
%
kRingBufferSize
;
const
IdStatistics
&
statistics
=
id_statistics_vec_
.
at
(
index
);
CHECK_EQ
(
statistics
.
iter
,
iter
)
<<
"saved iter: "
<<
statistics
.
iter
<<
" current iter: "
<<
iter
;
return
statistics
.
num_unique_matrix
;
}
void
*
lookup_unique_values_
;
void
*
lookup_embeddings_
;
bool
has_lookup_embeddings_
;
const
void
*
embedding_gather_in_
;
const
void
*
embedding_shuffle_cur_rank_embeddings_
;
const
void
*
embedding_update_unique_embeddings_
;
void
*
embedding_update_updated_unique_embeddings_
;
const
void
*
embedding_put_unique_embeddings_
;
const
void
*
embedding_fused_update_put_unique_embeddings_
;
std
::
vector
<
IdStatistics
>
id_statistics_vec_
;
std
::
mutex
mutex_
;
};
EmbeddingState
*
EmbeddingManager
::
GetEmbeddingState
(
const
std
::
string
&
embedding_name
,
int64_t
rank_id
)
{
std
::
pair
<
std
::
string
,
int64_t
>
map_key
=
std
::
make_pair
(
embedding_name
,
rank_id
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
auto
it
=
embedding_state_map_
.
find
(
map_key
);
// for id shuffle test, not need to create table
if
(
it
==
embedding_state_map_
.
end
())
{
LOG
(
INFO
)
<<
"create embedding state: "
<<
embedding_name
<<
"-"
<<
rank_id
;
if
(
UseDynamicMemoryAllocation
())
{
UNIMPLEMENTED
();
}
else
{
it
=
embedding_state_map_
.
emplace
(
map_key
,
std
::
make_unique
<
StaticAllocationEmbeddingState
>
())
.
first
;
}
}
return
it
->
second
.
get
();
}
KeyValueStore
*
EmbeddingManager
::
GetKeyValueStore
(
const
std
::
string
&
embedding_name
,
int64_t
rank_id
)
{
std
::
pair
<
std
::
string
,
int64_t
>
map_key
=
std
::
make_pair
(
embedding_name
,
rank_id
);
...
...
@@ -141,6 +835,15 @@ void EmbeddingManager::CreateKeyValueStore(const KeyValueStoreOptions& key_value
store
->
ReserveQueryLength
(
kDefaultMaxQueryLength
);
CHECK
(
key_value_store_map_
.
emplace
(
map_key
,
std
::
move
(
store
)).
second
)
<<
"Can't create an embedding with same name of an existing embedding, the name: "
<<
name
;
if
(
UseDynamicMemoryAllocation
())
{
UNIMPLEMENTED
();
}
else
{
CHECK
(
embedding_state_map_
.
emplace
(
map_key
,
std
::
make_unique
<
StaticAllocationEmbeddingState
>
())
.
second
)
<<
"Can't create an embedding state with same name of an existing embedding, the name: "
<<
name
;
}
}
void
EmbeddingManager
::
SaveSnapshot
(
const
std
::
string
&
embedding_name
,
int64_t
local_rank_id
,
...
...
@@ -170,7 +873,7 @@ void EmbeddingManager::LoadSnapshot(const std::string& embedding_name, int64_t l
}
}
#endif
// WITH_ROCM
#endif
}
// namespace embedding
...
...
oneflow/core/embedding/embedding_manager.h
View file @
a715222c
...
...
@@ -20,36 +20,149 @@ limitations under the License.
#include "oneflow/core/embedding/key_value_store.h"
#include "oneflow/core/embedding/key_value_store_options.h"
#include "oneflow/core/framework/framework.h"
namespace
oneflow
{
namespace
embedding
{
#ifdef WITH_CUDA
class
EmbeddingManager
final
{
inline
bool
UseDynamicMemoryAllocation
()
{
static
bool
use_dynamic_memory_allocation
=
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_USE_DYNAMIC_MEMORY_ALLOCATION"
,
false
);
#if CUDA_VERSION >= 11020
return
use_dynamic_memory_allocation
;
#else
if
(
use_dynamic_memory_allocation
)
{
LOG
(
WARNING
)
<<
"Dynamic memory allocation only support when cuda_version greater equal than 11.2. "
;
}
return
false
;
#endif
}
inline
bool
UseEmbeddingShuffleP2PKernel
(
DataType
embedding_dtype
,
DataType
idx_dtype
)
{
static
bool
use_embedding_shuffle_p2p_env
=
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_USE_P2P"
,
false
);
static
bool
add_id_shuffle_copy_out_env
=
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT"
,
true
);
static
bool
enable_quantized_comm
=
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"
,
false
);
if
(
use_embedding_shuffle_p2p_env
)
{
if
(
embedding_dtype
!=
DataType
::
kFloat16
||
idx_dtype
!=
DataType
::
kUInt32
)
{
// p2p kernel only registered kFloat16 and kUint32.
return
false
;
}
if
(
!
add_id_shuffle_copy_out_env
)
{
// when not enable id shuffle copy out, the ptrs change every iter.
return
false
;
}
if
(
enable_quantized_comm
)
{
// p2p kernel not support quantize comm.
return
false
;
}
if
(
UseDynamicMemoryAllocation
())
{
// p2p kernel not support dynamic memory allocation.
return
false
;
}
}
#if CUDA_VERSION >= 11030
return
use_embedding_shuffle_p2p_env
;
#else
if
(
use_embedding_shuffle_p2p_env
)
{
LOG
(
WARNING
)
<<
"embedding shuffle p2p kernel only support when cuda_version greater equal than 11.3. "
;
}
return
false
;
#endif
}
inline
bool
UseEmbeddingGradientShuffleP2PKernel
(
DataType
embedding_dtype
,
DataType
idx_dtype
)
{
static
bool
use_embedding_gradient_shuffle_p2p_env
=
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_EMBEDDING_GRADIENT_SHUFFLE_USE_P2P"
,
false
);
static
bool
add_id_shuffle_copy_out_env
=
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_ADD_ID_SHUFFLE_COPY_OUT"
,
true
);
static
bool
enable_quantized_comm
=
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM"
,
false
);
if
(
use_embedding_gradient_shuffle_p2p_env
)
{
if
(
embedding_dtype
!=
DataType
::
kFloat16
||
idx_dtype
!=
DataType
::
kUInt32
)
{
// p2p kernel only registered kFloat16 and kUint32.
return
false
;
}
if
(
!
add_id_shuffle_copy_out_env
)
{
// when not enable id shuffle copy out, the ptrs change every iter.
return
false
;
}
if
(
enable_quantized_comm
)
{
// p2p kernel not support quantize comm.
return
false
;
}
if
(
UseDynamicMemoryAllocation
())
{
// p2p kernel not support dynamic memory allocation.
return
false
;
}
}
#if CUDA_VERSION >= 11030
return
use_embedding_gradient_shuffle_p2p_env
;
#else
if
(
use_embedding_gradient_shuffle_p2p_env
)
{
LOG
(
WARNING
)
<<
"embedding gradient shuffle p2p kernel only support when cuda_version greater "
"equal than 11.3. "
;
}
return
false
;
#endif
}
#if defined(WITH_CUDA) || defined(WITH_ROCM)
class
TmpBufferAllocator
{
public:
EmbeddingManager
()
=
default
;
~
EmbeddingManager
()
=
default
;
void
SaveSnapshot
(
const
std
::
string
&
embedding_name
,
int64_t
local_rank_id
,
int64_t
rank_id
,
const
std
::
string
&
snapshot_name
);
void
LoadSnapshot
(
const
std
::
string
&
embedding_name
,
int64_t
local_rank_id
,
int64_t
rank_id
,
const
std
::
string
&
snapshot_name
);
TmpBufferAllocator
()
=
default
;
virtual
~
TmpBufferAllocator
()
=
default
;
KeyValueStore
*
GetKeyValueStore
(
const
std
::
string
&
embedding_name
,
int64_t
rank_id
);
void
CreateKeyValueStore
(
const
KeyValueStoreOptions
&
options
,
int64_t
local_rank_id
,
int64_t
rank_id
,
int64_t
world_size
);
private:
HashMap
<
std
::
pair
<
std
::
string
,
int64_t
>
,
std
::
unique_ptr
<
KeyValueStore
>>
key_value_store_map_
;
std
::
mutex
mutex_
;
virtual
void
Allocate
(
void
**
ptr
,
size_t
size
)
=
0
;
virtual
void
Free
(
void
*
ptr
)
=
0
;
};
#endif // WITH_CUDA
#ifdef WITH_ROCM
class
EmbeddingState
{
public:
EmbeddingState
()
=
default
;
virtual
~
EmbeddingState
()
=
default
;
virtual
std
::
unique_ptr
<
TmpBufferAllocator
>
NewTmpBufferAllocator
(
user_op
::
KernelComputeContext
*
ctx
)
=
0
;
virtual
void
OnEmbeddingLookupStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
void
*
LookupUniqueValues
(
int64_t
iter
)
=
0
;
virtual
void
*
LookupEmbeddings
(
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingLookupEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingGatherStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
const
void
*
EmbeddingGatherIn
(
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingGatherEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingShuffleStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
const
void
*
EmbeddingShuffleCurRankEmbeddings
(
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingShuffleEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingUpdateStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
const
void
*
EmbeddingUpdateUniqueEmbeddings
(
int64_t
iter
)
=
0
;
virtual
void
*
EmbeddingUpdateUpdatedUniqueEmbeddings
(
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingUpdateEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingPutStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
const
void
*
EmbeddingPutUniqueEmbeddings
(
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingPutEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingFusedUpdatePutStart
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
const
void
*
EmbeddingFusedUpdatePutUniqueEmbeddings
(
int64_t
iter
)
=
0
;
virtual
void
OnEmbeddingFusedUpdatePutEnd
(
user_op
::
KernelComputeContext
*
ctx
,
int64_t
iter
)
=
0
;
virtual
void
SetIdFinalNumUnique
(
uint32_t
final_num_unique
,
int64_t
iter
)
=
0
;
virtual
void
SetIdNumUniqueMatrix
(
const
std
::
vector
<
uint32_t
>&
num_unique_matrix
,
int64_t
iter
)
=
0
;
virtual
uint32_t
GetIdNumUnique
(
int64_t
iter
)
=
0
;
virtual
const
std
::
vector
<
uint32_t
>&
GetIdNumUniqueMatrix
(
int64_t
iter
)
=
0
;
};
class
EmbeddingManager
final
{
public:
...
...
@@ -62,16 +175,17 @@ class EmbeddingManager final {
const
std
::
string
&
snapshot_name
);
KeyValueStore
*
GetKeyValueStore
(
const
std
::
string
&
embedding_name
,
int64_t
rank_id
);
EmbeddingState
*
GetEmbeddingState
(
const
std
::
string
&
embedding_name
,
int64_t
rank_id
);
void
CreateKeyValueStore
(
const
KeyValueStoreOptions
&
options
,
int64_t
local_rank_id
,
int64_t
rank_id
,
int64_t
world_size
);
private:
HashMap
<
std
::
pair
<
std
::
string
,
int64_t
>
,
std
::
unique_ptr
<
KeyValueStore
>>
key_value_store_map_
;
HashMap
<
std
::
pair
<
std
::
string
,
int64_t
>
,
std
::
unique_ptr
<
EmbeddingState
>>
embedding_state_map_
;
std
::
mutex
mutex_
;
};
#endif // WITH_
ROCM
#endif // WITH_
CUDA
}
// namespace embedding
}
// namespace oneflow
...
...
oneflow/core/embedding/full_cache.cu
View file @
a715222c
...
...
@@ -28,9 +28,9 @@ using Key128 = ulonglong2;
namespace
{
template
<
typename
Key
,
typename
Index
>
__device__
bool
TryGetOrInsert
(
Key
*
entry_key
,
volatile
Index
*
entry_index
,
Index
*
table_size
,
Key
key
,
Index
*
out
)
{
template
<
typename
Key
,
typename
Index
,
bool
dump_dirty_only
>
__device__
bool
TryGetOrInsert
(
Key
*
entry_key
,
volatile
Index
*
entry_index
,
bool
*
entry_dirty_flag
,
Index
*
table_size
,
Key
key
,
Index
*
out
)
{
Key
key_hi
=
(
key
|
0x1
);
Key
key_lo
=
(
key
&
0x1
);
Index
index_plus_one
=
0
;
...
...
@@ -41,6 +41,10 @@ __device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Inde
index_plus_one
=
index
+
1
;
*
entry_index
=
((
index_plus_one
<<
1U
)
|
key_lo
);
*
out
=
index_plus_one
;
if
(
dump_dirty_only
)
{
bool
entry_flag_val
=
*
entry_dirty_flag
;
if
(
!
entry_flag_val
)
{
*
entry_dirty_flag
=
true
;
}
}
return
true
;
}
else
if
(
old_entry_key
==
key_hi
)
{
const
Index
entry_index_val
=
*
entry_index
;
...
...
@@ -48,6 +52,10 @@ __device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Inde
// do nothing
}
else
if
((
entry_index_val
&
0x1
)
==
key_lo
)
{
*
out
=
(
entry_index_val
>>
1U
);
if
(
dump_dirty_only
)
{
bool
entry_flag_val
=
*
entry_dirty_flag
;
if
(
!
entry_flag_val
)
{
*
entry_dirty_flag
=
true
;
}
}
return
true
;
}
else
{
return
false
;
...
...
@@ -59,15 +67,20 @@ __device__ bool TryGetOrInsert(Key* entry_key, volatile Index* entry_index, Inde
return
false
;
}
template
<
typename
Key
,
typename
Index
>
template
<
typename
Key
,
typename
Index
,
bool
dump_dirty_only
>
__device__
bool
GetOrInsertOne
(
const
size_t
capacity
,
Key
*
table_keys
,
Index
*
table_indices
,
Index
*
table_size
,
Key
key
,
size_t
hash
,
Index
*
out
)
{
bool
*
table_dirty_flags
,
Index
*
table_size
,
Key
key
,
size_t
hash
,
Index
*
out
)
{
const
size_t
start_idx
=
hash
%
capacity
;
for
(
size_t
count
=
0
;
count
<
capacity
;
++
count
)
{
const
size_t
idx
=
(
start_idx
+
count
)
%
capacity
;
Key
*
entry_key
=
table_keys
+
idx
;
Index
*
entry_index
=
table_indices
+
idx
;
if
(
TryGetOrInsert
<
Key
,
Index
>
(
entry_key
,
entry_index
,
table_size
,
key
,
out
))
{
return
true
;
}
bool
*
entry_dirty_flag
=
dump_dirty_only
?
table_dirty_flags
+
idx
:
nullptr
;
if
(
TryGetOrInsert
<
Key
,
Index
,
dump_dirty_only
>
(
entry_key
,
entry_index
,
entry_dirty_flag
,
table_size
,
key
,
out
))
{
return
true
;
}
}
return
false
;
}
...
...
@@ -94,15 +107,15 @@ __device__ bool GetOne(const size_t capacity, Key* table_keys, Index* table_indi
return
false
;
}
template
<
typename
Key
,
typename
Index
>
template
<
typename
Key
,
typename
Index
,
bool
dump_dirty_only
>
__global__
void
OrdinalEncodeKernel
(
uint64_t
capacity
,
Key
*
table_keys
,
Index
*
table_indices
,
Index
*
table_size
,
uint32_t
num_keys
,
const
Key
*
keys
,
Index
*
context
)
{
bool
*
table_dirty_flags
,
Index
*
table_size
,
uint32_t
num_keys
,
const
Key
*
keys
,
Index
*
context
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
num_keys
)
{
Key
key
=
keys
[
i
];
uint64_t
hash
=
FullCacheHash
()(
key
);
bool
success
=
GetOrInsertOne
<
Key
,
Index
>
(
capacity
,
table_keys
,
table_indices
,
table_size
,
key
,
hash
,
context
+
i
);
bool
success
=
GetOrInsertOne
<
Key
,
Index
,
dump_dirty_only
>
(
capacity
,
table_keys
,
table_indices
,
table_dirty_flags
,
table_size
,
key
,
hash
,
context
+
i
);
assert
(
success
);
}
}
...
...
@@ -117,14 +130,20 @@ __global__ void OrdinalEncodeLookupKernel(uint64_t capacity, Key* table_keys, In
}
}
template
<
typename
Key
,
typename
Index
>
template
<
typename
Key
,
typename
Index
,
bool
dump_dirty_only
>
__global__
void
OrdinalEncodeDumpKernel
(
const
Key
*
table_keys
,
const
Index
*
table_indices
,
uint64_t
start_key_index
,
uint64_t
end_key_index
,
uint32_t
*
n_dumped
,
Key
*
keys
,
Index
*
context
)
{
const
bool
*
table_dirty_flags
,
uint64_t
start_key_index
,
uint64_t
end_key_index
,
uint32_t
*
n_dumped
,
Key
*
keys
,
Index
*
context
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
(
end_key_index
-
start_key_index
))
{
Key
entry_key
=
table_keys
[
i
+
start_key_index
];
Index
entry_index
=
table_indices
[
i
+
start_key_index
];
if
(
entry_index
!=
0
)
{
bool
dump_flag
=
(
entry_index
!=
0
);
if
(
dump_dirty_only
)
{
bool
entry_dirty_flag
=
table_dirty_flags
[
i
+
start_key_index
];
dump_flag
=
(
dump_flag
&&
entry_dirty_flag
);
}
if
(
dump_flag
)
{
uint32_t
index
=
cuda
::
atomic
::
Add
(
n_dumped
,
static_cast
<
uint32_t
>
(
1
));
keys
[
index
]
=
((
entry_key
^
0x1
)
|
(
entry_index
&
0x1
));
context
[
index
]
=
(
entry_index
>>
1U
);
...
...
@@ -177,7 +196,11 @@ __global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_valu
batch_start
+=
global_n_warp
*
warp_size
)
{
const
uint32_t
batch_n_key
=
min
(
n_keys
-
batch_start
,
warp_size
);
if
(
lane_id
==
0
)
{
batch_n_missing
[
warp_id
]
=
0
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
const
uint32_t
key_offset
=
batch_start
+
lane_id
;
if
(
key_offset
<
n_keys
)
{
const
Key
key
=
keys
[
batch_start
+
lane_id
];
...
...
@@ -191,14 +214,22 @@ __global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_valu
batch_missing_indices
[
warp_id
][
batch_missing_idx
]
=
key_offset
;
}
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
const
uint32_t
batch_n_missing_t
=
batch_n_missing
[
warp_id
];
if
(
lane_id
==
0
)
{
const
uint32_t
old_n_missing
=
cuda
::
atomic
::
Add
(
n_missing
,
static_cast
<
uint32_t
>
(
batch_n_missing_t
));
batch_n_missing
[
warp_id
]
=
old_n_missing
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
if
(
lane_id
<
batch_n_missing_t
)
{
missing_keys
[
batch_n_missing
[
warp_id
]
+
lane_id
]
=
batch_missing_keys
[
warp_id
][
lane_id
];
missing_indices
[
batch_n_missing
[
warp_id
]
+
lane_id
]
=
batch_missing_indices
[
warp_id
][
lane_id
];
...
...
@@ -212,7 +243,11 @@ __global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_valu
cache_values
[(
row
-
1
)
*
value_length
+
col
];
}
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
}
...
...
@@ -252,7 +287,11 @@ __global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __rest
batch_row_ids
[
warp_id
][
lane_id
]
=
row
;
mask
[
key_offset
]
=
row
>
0
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
for
(
int
i
=
0
;
i
<
batch_n_key
;
++
i
)
{
const
Key
key
=
batch_keys
[
warp_id
][
i
];
const
Index
row
=
batch_row_ids
[
warp_id
][
i
];
...
...
@@ -263,7 +302,11 @@ __global__ void EncodeLookupMaskKernel(uint32_t value_length, const Elem* __rest
packed_cache_values
[(
row
-
1
)
*
packed_cols
+
col
];
}
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
}
...
...
@@ -314,7 +357,7 @@ __global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::typ
FusedHalfUpdateKernel
(
uint32_t
value_length
,
Elem
*
cache_values
,
uint32_t
values_elem_cnt
,
const
Index
*
context
,
const
Elem
*
values
,
const
half
*
update
,
const
float
*
lr
,
float
scale
)
{
__trap
();
TRAP
();
}
template
<
typename
Key
,
typename
Elem
,
typename
Index
>
...
...
@@ -333,33 +376,39 @@ template<typename Key, typename Index>
class
OrdinalEncoder
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
OrdinalEncoder
);
explicit
OrdinalEncoder
(
uint64_t
capacity
,
float
load_factor
)
:
capacity_
(
capacity
),
table_capacity_
(
capacity
/
load_factor
)
{
OF_CUDA_CHECK
(
cudaGetDevice
(
&
device_index_
));
OF_CUDA_CHECK
(
cudaMalloc
(
&
table_size_
,
sizeof
(
Index
)));
explicit
OrdinalEncoder
(
uint64_t
capacity
,
float
load_factor
,
bool
if_dump_dirty
)
:
capacity_
(
capacity
),
table_capacity_
(
capacity
/
load_factor
),
if_dump_dirty_
(
if_dump_dirty
)
{
OF_CUDA_CHECK
(
GPU
(
GetDevice
)(
&
device_index_
));
OF_CUDA_CHECK
(
GPU
(
Malloc
)(
&
table_size_
,
sizeof
(
Index
)));
#ifdef WITH_ROCM
OF_CUDA_CHECK
(
hipMallocHost
(
reinterpret_cast
<
void
**>
(
&
table_size_host_
),
sizeof
(
Index
)));
#else
OF_CUDA_CHECK
(
cudaMallocHost
(
&
table_size_host_
,
sizeof
(
Index
)));
OF_CUDA_CHECK
(
cudaMalloc
(
&
table_keys_
,
table_capacity_
*
sizeof
(
Key
)));
OF_CUDA_CHECK
(
cudaMalloc
(
&
table_indices_
,
table_capacity_
*
sizeof
(
Index
)));
#endif
OF_CUDA_CHECK
(
GPU
(
Malloc
)(
&
table_keys_
,
table_capacity_
*
sizeof
(
Key
)));
OF_CUDA_CHECK
(
GPU
(
Malloc
)(
&
table_indices_
,
table_capacity_
*
sizeof
(
Index
)));
if
(
if_dump_dirty_
)
{
OF_CUDA_CHECK
(
GPU
(
Malloc
)(
&
table_dirty_flags_
,
table_capacity_
*
sizeof
(
bool
)));
}
Clear
();
}
~
OrdinalEncoder
()
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
OF_CUDA_CHECK
(
cudaFree
(
table_size_
));
OF_CUDA_CHECK
(
cudaFreeHost
(
table_size_host_
));
OF_CUDA_CHECK
(
cudaFree
(
table_keys_
));
OF_CUDA_CHECK
(
cudaFree
(
table_indices_
));
OF_CUDA_CHECK
(
GPU
(
Free
)(
table_size_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)(
table_size_host_
));
OF_CUDA_CHECK
(
GPU
(
Free
)(
table_keys_
));
OF_CUDA_CHECK
(
GPU
(
Free
)(
table_indices_
));
if
(
if_dump_dirty_
)
{
OF_CUDA_CHECK
(
GPU
(
Free
)(
table_dirty_flags_
));
}
}
template
<
bool
insert
>
template
<
bool
insert
,
bool
dump_dirty_only
>
void
Encode
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
Key
*
keys
,
Index
*
context
)
{
if
(
insert
)
{
RUN_CUDA_KERNEL
((
OrdinalEncodeKernel
<
Key
,
Index
>
),
stream
,
num_keys
,
table_capacity_
,
table_keys_
,
table_indices_
,
table_size_
,
num_keys
,
keys
,
context
);
OF_CUDA_CHECK
(
cudaMemcpyAsync
(
table_size_host_
,
table_size_
,
sizeof
(
Index
),
cudaMemcpyDefault
,
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
CHECK_JUST
(
stream
->
Sync
());
CHECK_LT
(
*
table_size_host_
,
capacity_
)
<<
"The number of key is larger than cache size, please enlarge cache_memory_budget. "
;
RUN_CUDA_KERNEL
((
OrdinalEncodeKernel
<
Key
,
Index
,
dump_dirty_only
>
),
stream
,
num_keys
,
table_capacity_
,
table_keys_
,
table_indices_
,
table_dirty_flags_
,
table_size_
,
num_keys
,
keys
,
context
);
}
else
{
RUN_CUDA_KERNEL
((
OrdinalEncodeLookupKernel
<
Key
,
Index
>
),
stream
,
num_keys
,
table_capacity_
,
table_keys_
,
table_indices_
,
num_keys
,
keys
,
context
);
...
...
@@ -368,17 +417,35 @@ class OrdinalEncoder {
void
Dump
(
ep
::
Stream
*
stream
,
uint64_t
start_key_index
,
uint64_t
end_key_index
,
uint32_t
*
n_dumped
,
Key
*
keys
,
Index
*
context
)
{
OF_CUDA_CHECK
(
cudaMemsetAsync
(
n_dumped
,
0
,
sizeof
(
uint32_t
),
OF_CUDA_CHECK
(
GPU
(
MemsetAsync
)(
n_dumped
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
RUN_CUDA_KERNEL
((
OrdinalEncodeDumpKernel
<
Key
,
Index
,
false
>
),
stream
,
end_key_index
-
start_key_index
,
table_keys_
,
table_indices_
,
table_dirty_flags_
,
start_key_index
,
end_key_index
,
n_dumped
,
keys
,
context
);
}
void
DumpDirtyOnly
(
ep
::
Stream
*
stream
,
uint64_t
start_key_index
,
uint64_t
end_key_index
,
uint32_t
*
n_dumped
,
Key
*
keys
,
Index
*
context
)
{
OF_CUDA_CHECK
(
GPU
(
MemsetAsync
)(
n_dumped
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
RUN_CUDA_KERNEL
((
OrdinalEncodeDumpKernel
<
Key
,
Index
>
),
stream
,
end_key_index
-
start_key_index
,
table_keys_
,
table_indices_
,
start_key_index
,
end_key_index
,
n_dumped
,
keys
,
context
);
RUN_CUDA_KERNEL
((
OrdinalEncodeDumpKernel
<
Key
,
Index
,
true
>
),
stream
,
end_key_index
-
start_key_index
,
table_keys_
,
table_indices_
,
table_dirty_flags_
,
start_key_index
,
end_key_index
,
n_dumped
,
keys
,
context
);
}
void
ClearDirtyFlags
()
{
if
(
if_dump_dirty_
)
{
OF_CUDA_CHECK
(
GPU
(
Memset
)(
table_dirty_flags_
,
0
,
table_capacity_
*
sizeof
(
bool
)));
}
}
void
Clear
()
{
OF_CUDA_CHECK
(
cudaMemset
(
table_size_
,
0
,
sizeof
(
Index
)));
OF_CUDA_CHECK
(
cudaMemset
(
table_keys_
,
0
,
table_capacity_
*
sizeof
(
Key
)));
OF_CUDA_CHECK
(
cudaMemset
(
table_indices_
,
0
,
table_capacity_
*
sizeof
(
Index
)));
OF_CUDA_CHECK
(
GPU
(
Memset
)(
table_size_
,
0
,
sizeof
(
Index
)));
OF_CUDA_CHECK
(
GPU
(
Memset
)(
table_keys_
,
0
,
table_capacity_
*
sizeof
(
Key
)));
OF_CUDA_CHECK
(
GPU
(
Memset
)(
table_indices_
,
0
,
table_capacity_
*
sizeof
(
Index
)));
if
(
if_dump_dirty_
)
{
OF_CUDA_CHECK
(
GPU
(
Memset
)(
table_dirty_flags_
,
0
,
table_capacity_
*
sizeof
(
bool
)));
}
}
uint64_t
TableCapacity
()
const
{
return
table_capacity_
;
}
...
...
@@ -391,8 +458,10 @@ class OrdinalEncoder {
int
device_index_
{};
Key
*
table_keys_
;
Index
*
table_indices_
;
bool
*
table_dirty_flags_
;
uint64_t
capacity_
;
uint64_t
table_capacity_
;
bool
if_dump_dirty_
;
Index
*
table_size_
{};
Index
*
table_size_host_
{};
};
...
...
@@ -402,17 +471,22 @@ class CacheImpl : public Cache {
public:
OF_DISALLOW_COPY_AND_MOVE
(
CacheImpl
);
explicit
CacheImpl
(
const
CacheOptions
&
options
)
:
encoder_
(
options
.
capacity
,
options
.
load_factor
),
:
if_dump_dirty_
(
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_DUMP_DIRTY_ONLY"
,
false
)),
encoder_
(
options
.
capacity
,
options
.
load_factor
,
if_dump_dirty_
),
device_index_
(
-
1
),
options_
(
options
),
max_query_length_
(
0
)
{
OF_CUDA_CHECK
(
cuda
GetDevice
(
&
device_index_
));
OF_CUDA_CHECK
(
GPU
(
GetDevice
)
(
&
device_index_
));
const
uint64_t
values_size
=
options
.
capacity
*
options
.
value_size
;
if
(
options
.
value_memory_kind
==
CacheOptions
::
MemoryKind
::
kDevice
)
{
OF_CUDA_CHECK
(
cuda
Malloc
(
&
values_
,
values_size
));
OF_CUDA_CHECK
(
GPU
(
Malloc
)
(
&
values_
,
values_size
));
}
else
if
(
options
.
value_memory_kind
==
CacheOptions
::
MemoryKind
::
kHost
)
{
if
(
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION"
,
false
))
{
OF_CUDA_CHECK
(
cudaMallocHost
(
&
values_
,
values_size
));
#ifdef WITH_ROCM
OF_CUDA_CHECK
(
hipMallocHost
(
reinterpret_cast
<
void
**>
(
&
values_
),
values_size
));
#else
OF_CUDA_CHECK
(
cudaMallocHost
(
&
values_
,
values_size
));
#endif
}
else
{
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
values_
),
values_size
));
...
...
@@ -425,13 +499,13 @@ class CacheImpl : public Cache {
~
CacheImpl
()
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
options_
.
value_memory_kind
==
CacheOptions
::
MemoryKind
::
kDevice
)
{
OF_CUDA_CHECK
(
cuda
Free
(
values_
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
values_
));
}
else
if
(
options_
.
value_memory_kind
==
CacheOptions
::
MemoryKind
::
kHost
)
{
OF_CUDA_CHECK
(
cuda
FreeHost
(
values_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
values_
));
}
else
{
UNIMPLEMENTED
();
}
if
(
max_query_length_
>
0
)
{
OF_CUDA_CHECK
(
cuda
Free
(
encoding_buffer_
));
}
if
(
max_query_length_
>
0
)
{
OF_CUDA_CHECK
(
GPU
(
Free
)
(
encoding_buffer_
));
}
}
uint64_t
Capacity
()
const
override
{
return
options_
.
capacity
;
}
...
...
@@ -447,8 +521,8 @@ class CacheImpl : public Cache {
void
ReserveQueryLength
(
uint32_t
query_length
)
override
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
query_length
<=
max_query_length_
)
{
return
;
}
if
(
max_query_length_
>
0
)
{
OF_CUDA_CHECK
(
cuda
Free
(
encoding_buffer_
));
}
OF_CUDA_CHECK
(
cuda
Malloc
(
&
encoding_buffer_
,
query_length
*
sizeof
(
uint64_t
)));
if
(
max_query_length_
>
0
)
{
OF_CUDA_CHECK
(
GPU
(
Free
)
(
encoding_buffer_
));
}
OF_CUDA_CHECK
(
GPU
(
Malloc
)
(
&
encoding_buffer_
,
query_length
*
sizeof
(
uint64_t
)));
max_query_length_
=
query_length
;
}
...
...
@@ -465,15 +539,19 @@ class CacheImpl : public Cache {
void
Put
(
ep
::
Stream
*
stream
,
uint32_t
n_keys
,
const
void
*
keys
,
const
void
*
values
,
uint32_t
*
n_evicted
,
void
*
evicted_keys
,
void
*
evicted_values
)
override
;
void
FusedHalfUpdatePut
(
ep
::
Stream
*
stream
,
uint32_t
n_keys
,
const
void
*
keys
,
const
void
*
values
,
const
void
*
update
,
const
float
*
lr
,
float
scale
,
uint32_t
*
n_evicted
,
void
*
evicted_keys
,
void
*
evicted_values
)
override
;
void
Dump
(
ep
::
Stream
*
stream
,
uint64_t
start_key_index
,
uint64_t
end_key_index
,
uint32_t
*
n_dumped
,
void
*
keys
,
void
*
values
)
override
;
void
ClearDirtyFlags
()
override
;
void
Clear
()
override
;
private:
bool
if_dump_dirty_
;
OrdinalEncoder
<
Key
,
Index
>
encoder_
;
int
device_index_
;
uint32_t
num_elem_per_value_
{};
...
...
@@ -488,10 +566,16 @@ void CacheImpl<Key, Elem, Index, pack_size>::Test(ep::Stream* stream, uint32_t n
const
void
*
keys
,
uint32_t
*
n_missing
,
void
*
missing_keys
,
uint32_t
*
missing_indices
)
{
OF_CUDA_CHECK
(
cuda
MemsetAsync
(
n_missing
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
GPU
(
MemsetAsync
)
(
n_missing
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
if
(
n_keys
==
0
)
{
return
;
}
CHECK_LE
(
n_keys
,
max_query_length_
);
encoder_
.
template
Encode
<
false
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
if
(
if_dump_dirty_
)
{
encoder_
.
template
Encode
<
false
,
true
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
}
else
{
encoder_
.
template
Encode
<
false
,
false
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
}
const
uint32_t
values_elem_cnt
=
n_keys
*
num_elem_per_value_
;
RUN_CUDA_KERNEL
((
LookupKernel
<
Key
,
Elem
,
Index
,
false
>
),
stream
,
values_elem_cnt
,
num_elem_per_value_
,
values_
,
values_elem_cnt
,
static_cast
<
const
Key
*>
(
keys
),
...
...
@@ -505,7 +589,7 @@ void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_
uint32_t
*
n_missing
,
void
*
missing_keys
,
uint32_t
*
missing_indices
)
{
OF_CUDA_CHECK
(
cuda
MemsetAsync
(
n_missing
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
GPU
(
MemsetAsync
)
(
n_missing
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
if
(
n_keys
==
0
)
{
return
;
}
CHECK_LE
(
n_keys
,
max_query_length_
);
constexpr
uint32_t
block_size
=
128
;
...
...
@@ -539,11 +623,15 @@ void CacheImpl<Key, Elem, Index, pack_size>::Put(ep::Stream* stream, uint32_t n_
const
void
*
keys
,
const
void
*
values
,
uint32_t
*
n_evicted
,
void
*
evicted_keys
,
void
*
evicted_values
)
{
OF_CUDA_CHECK
(
cudaMemsetAsync
(
n_evicted
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
if
(
n_keys
==
0
)
{
return
;
}
CHECK_LE
(
n_keys
,
max_query_length_
);
encoder_
.
template
Encode
<
true
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
if
(
if_dump_dirty_
)
{
encoder_
.
template
Encode
<
true
,
true
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
}
else
{
encoder_
.
template
Encode
<
true
,
false
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
}
const
uint32_t
values_elem_cnt
=
n_keys
*
num_elem_per_value_
;
RUN_CUDA_KERNEL
((
UpdateKernel
<
Elem
,
Index
,
pack_size
>
),
stream
,
values_elem_cnt
/
pack_size
,
num_elem_per_value_
,
values_
,
values_elem_cnt
,
encoding_buffer_
,
...
...
@@ -555,28 +643,43 @@ void CacheImpl<Key, Elem, Index, pack_size>::FusedHalfUpdatePut(
ep
::
Stream
*
stream
,
uint32_t
n_keys
,
const
void
*
keys
,
const
void
*
values
,
const
void
*
update
,
const
float
*
lr
,
float
scale
,
uint32_t
*
n_evicted
,
void
*
evicted_keys
,
void
*
evicted_values
)
{
if
(
!
std
::
is_same
<
Elem
,
float
>::
value
)
{
UNIMPLEMENTED
();
}
OF_CUDA_CHECK
(
cudaMemsetAsync
(
n_evicted
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
if
(
n_keys
==
0
)
{
return
;
}
CHECK_LE
(
n_keys
,
max_query_length_
);
encoder_
.
template
Encode
<
true
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
if
(
if_dump_dirty_
)
{
encoder_
.
template
Encode
<
true
,
true
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
}
else
{
encoder_
.
template
Encode
<
true
,
false
>(
stream
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
encoding_buffer_
);
}
const
uint32_t
values_elem_cnt
=
n_keys
*
num_elem_per_value_
;
RUN_CUDA_KERNEL
((
FusedHalfUpdateKernel
<
Elem
,
Index
,
pack_size
>
),
stream
,
values_elem_cnt
/
pack_size
,
num_elem_per_value_
,
values_
,
values_elem_cnt
,
encoding_buffer_
,
static_cast
<
const
Elem
*>
(
values
),
static_cast
<
const
half
*>
(
update
),
lr
,
scale
);
}
template
<
typename
Key
,
typename
Elem
,
typename
Index
,
size_t
pack_size
>
void
CacheImpl
<
Key
,
Elem
,
Index
,
pack_size
>::
Dump
(
ep
::
Stream
*
stream
,
uint64_t
start_key_index
,
uint64_t
end_key_index
,
uint32_t
*
n_dumped
,
void
*
keys
,
void
*
values
)
{
encoder_
.
Dump
(
stream
,
start_key_index
,
end_key_index
,
n_dumped
,
static_cast
<
Key
*>
(
keys
),
encoding_buffer_
);
if
(
if_dump_dirty_
)
{
encoder_
.
DumpDirtyOnly
(
stream
,
start_key_index
,
end_key_index
,
n_dumped
,
static_cast
<
Key
*>
(
keys
),
encoding_buffer_
);
}
else
{
encoder_
.
Dump
(
stream
,
start_key_index
,
end_key_index
,
n_dumped
,
static_cast
<
Key
*>
(
keys
),
encoding_buffer_
);
}
RUN_CUDA_KERNEL
((
DumpValueKernel
<
Key
,
Elem
,
Index
>
),
stream
,
num_elem_per_value_
*
(
end_key_index
-
start_key_index
),
num_elem_per_value_
,
n_dumped
,
encoding_buffer_
,
values_
,
static_cast
<
Elem
*>
(
values
));
}
template
<
typename
Key
,
typename
Elem
,
typename
Index
,
size_t
pack_size
>
void
CacheImpl
<
Key
,
Elem
,
Index
,
pack_size
>::
ClearDirtyFlags
()
{
encoder_
.
ClearDirtyFlags
();
}
template
<
typename
Key
,
typename
Elem
,
typename
Index
,
size_t
pack_size
>
void
CacheImpl
<
Key
,
Elem
,
Index
,
pack_size
>::
Clear
()
{
encoder_
.
Clear
();
...
...
oneflow/core/embedding/full_cache.h
View file @
a715222c
...
...
@@ -23,16 +23,11 @@ namespace oneflow {
namespace
embedding
{
#ifdef
WITH_CUDA
#if
def
ined(
WITH_CUDA
) || defined(WITH_ROCM)
std
::
unique_ptr
<
Cache
>
NewFullCache
(
const
CacheOptions
&
options
);
#endif // WITH_CUDA
#ifdef WITH_ROCM
std
::
unique_ptr
<
Cache
>
NewFullCache
(
const
CacheOptions
&
options
);
#endif // WITH_ROCM
}
// namespace embedding
...
...
oneflow/core/embedding/full_cache.hip.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/full_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include "oneflow/core/hip/atomic.hip.h"
namespace
oneflow
{
namespace
embedding
{
using
Key32
=
unsigned
int
;
using
Key64
=
unsigned
long
long
int
;
using
Key128
=
ulonglong2
;
namespace
{
template
<
typename
Key
,
typename
Index
>
__device__
bool
TryGetOrInsert
(
Key
*
entry_key
,
volatile
Index
*
entry_index
,
Index
*
table_size
,
Key
key
,
Index
*
out
)
{
Key
key_hi
=
(
key
|
0x1
);
Key
key_lo
=
(
key
&
0x1
);
Index
index_plus_one
=
0
;
Key
old_entry_key
=
cuda
::
atomic
::
CAS
(
entry_key
,
static_cast
<
Key
>
(
0
),
key_hi
);
while
(
index_plus_one
==
0
)
{
if
(
old_entry_key
==
static_cast
<
Key
>
(
0
))
{
Index
index
=
cuda
::
atomic
::
Add
(
table_size
,
static_cast
<
Index
>
(
1
));
index_plus_one
=
index
+
1
;
*
entry_index
=
((
index_plus_one
<<
1U
)
|
key_lo
);
*
out
=
index_plus_one
;
return
true
;
}
else
if
(
old_entry_key
==
key_hi
)
{
const
Index
entry_index_val
=
*
entry_index
;
if
(
entry_index_val
==
0
)
{
// do nothing
}
else
if
((
entry_index_val
&
0x1
)
==
key_lo
)
{
*
out
=
(
entry_index_val
>>
1U
);
return
true
;
}
else
{
return
false
;
}
}
else
{
return
false
;
}
}
return
false
;
}
template
<
typename
Key
,
typename
Index
>
__device__
bool
GetOrInsertOne
(
const
size_t
capacity
,
Key
*
table_keys
,
Index
*
table_indices
,
Index
*
table_size
,
Key
key
,
size_t
hash
,
Index
*
out
)
{
const
size_t
start_idx
=
hash
%
capacity
;
for
(
size_t
count
=
0
;
count
<
capacity
;
++
count
)
{
const
size_t
idx
=
(
start_idx
+
count
)
%
capacity
;
Key
*
entry_key
=
table_keys
+
idx
;
Index
*
entry_index
=
table_indices
+
idx
;
if
(
TryGetOrInsert
<
Key
,
Index
>
(
entry_key
,
entry_index
,
table_size
,
key
,
out
))
{
return
true
;
}
}
return
false
;
}
template
<
typename
Key
,
typename
Index
>
__device__
bool
GetOne
(
const
size_t
capacity
,
Key
*
table_keys
,
Index
*
table_indices
,
Key
key
,
size_t
hash
,
Index
*
out
)
{
const
size_t
start_idx
=
hash
%
capacity
;
for
(
size_t
count
=
0
;
count
<
capacity
;
++
count
)
{
const
size_t
idx
=
(
start_idx
+
count
)
%
capacity
;
Key
entry_key
=
table_keys
[
idx
];
Key
entry_index
=
table_indices
[
idx
];
Key
key_hi
=
(
key
|
0x1
);
Key
key_lo
=
(
key
&
0x1
);
if
(
entry_key
==
0
)
{
break
;
}
if
(
entry_key
==
key_hi
)
{
if
((
entry_index
&
0x1
)
==
key_lo
)
{
*
out
=
(
entry_index
>>
1U
);
return
true
;
}
}
}
*
out
=
0
;
return
false
;
}
template
<
typename
Key
,
typename
Index
>
__global__
void
OrdinalEncodeKernel
(
uint64_t
capacity
,
Key
*
table_keys
,
Index
*
table_indices
,
Index
*
table_size
,
uint32_t
num_keys
,
const
Key
*
keys
,
Index
*
context
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
num_keys
)
{
Key
key
=
keys
[
i
];
uint64_t
hash
=
FullCacheHash
()(
key
);
bool
success
=
GetOrInsertOne
<
Key
,
Index
>
(
capacity
,
table_keys
,
table_indices
,
table_size
,
key
,
hash
,
context
+
i
);
assert
(
success
);
}
}
template
<
typename
Key
,
typename
Index
>
__global__
void
OrdinalEncodeLookupKernel
(
uint64_t
capacity
,
Key
*
table_keys
,
Index
*
table_indices
,
uint32_t
num_keys
,
const
Key
*
keys
,
Index
*
context
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
num_keys
)
{
Key
key
=
keys
[
i
];
uint64_t
hash
=
FullCacheHash
()(
key
);
GetOne
<
Key
,
Index
>
(
capacity
,
table_keys
,
table_indices
,
key
,
hash
,
context
+
i
);
}
}
// Scans table slots [start_key_index, end_key_index) and compacts every
// occupied entry into keys/context. The caller must zero *n_dumped before
// launch; dump order is nondeterministic (atomic slot reservation).
template<typename Key, typename Index>
__global__ void OrdinalEncodeDumpKernel(const Key* table_keys, const Index* table_indices,
                                        uint64_t start_key_index, uint64_t end_key_index,
                                        uint32_t* n_dumped, Key* keys, Index* context) {
  CUDA_1D_KERNEL_LOOP(i, (end_key_index - start_key_index)) {
    Key entry_key = table_keys[i + start_key_index];
    Index entry_index = table_indices[i + start_key_index];
    if (entry_index != 0) {  // an index slot of 0 marks an empty table slot
      // Reserve a unique output position in the compacted arrays.
      uint32_t index = cuda::atomic::Add(n_dumped, static_cast<uint32_t>(1));
      // Stored keys carry bit 0 forced to 1, with the key's true low bit kept
      // in bit 0 of the index slot; XOR clears the forced bit and OR restores
      // the original low bit, recovering the original key.
      keys[index] = ((entry_key ^ 0x1) | (entry_index & 0x1));
      // The remaining bits of the index slot hold the ordinal id.
      context[index] = (entry_index >> 1U);
    }
  }
}
// Gathers cached value rows for pre-encoded keys. One thread iteration per
// value ELEMENT (values_elem_cnt = n_keys * value_length). context[k] holds
// row id + 1 for key k, or 0 when the key is absent. Missing keys are
// appended (key + position) to missing_keys/missing_indices via an atomic
// counter; *n_missing must be zeroed by the caller. When return_value is
// false (Test path) `values` may be nullptr — it is never written.
template<typename Key, typename Elem, typename Index, bool return_value>
__global__ void LookupKernel(uint32_t value_length, const Elem* cache_values,
                             uint32_t values_elem_cnt, const Key* keys, const Index* context,
                             Elem* values, uint32_t* n_missing, Key* missing_keys,
                             uint32_t* missing_indices) {
  CUDA_1D_KERNEL_LOOP(i, values_elem_cnt) {
    const uint64_t key_id = i / value_length;
    const uint64_t ctx = context[key_id];
    const uint64_t row_id = ctx - 1;  // valid only when ctx != 0
    const uint64_t col_id = i - key_id * value_length;
    if (ctx == 0) {
      const Key missing_key = keys[key_id];
      // Only the thread handling column 0 reports the miss, so each missing
      // key is recorded exactly once.
      if (col_id == 0) {
        const uint32_t old_n_missing = cuda::atomic::Add(n_missing, static_cast<uint32_t>(1));
        missing_keys[old_n_missing] = missing_key;
        missing_indices[old_n_missing] = key_id;
      }
      continue;
    }
    if (return_value) { values[i] = cache_values[row_id * value_length + col_id]; }
  }
}
// Fused hash-table lookup + value gather. Each warp owns a batch of up to 32
// keys: lane k resolves key k's cache row via GetOne, misses are compacted in
// shared memory and published to the global missing list, then the whole warp
// cooperatively copies each present row (coalesced column accesses).
// *n_missing must be zeroed by the caller. `context` is unused here: row ids
// are recomputed directly from the table.
template<typename Key, typename Elem, typename Index, uint32_t block_size>
__global__ void EncodeLookupKernel(uint32_t value_length, const Elem* cache_values,
                                   uint32_t values_elem_cnt, const Key* keys, const Index* context,
                                   Elem* values, uint32_t* n_missing, Key* missing_keys,
                                   uint32_t* missing_indices, const size_t capacity,
                                   Key* table_keys, Index* table_indices) {
  constexpr uint32_t warp_size = 32;
  constexpr uint32_t n_warp_per_block = block_size / warp_size;
  const uint32_t warp_id = threadIdx.x / warp_size;
  const uint32_t lane_id = threadIdx.x % warp_size;
  const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
  const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
  const uint32_t n_keys = values_elem_cnt / value_length;
  // Fix: the former `batch_keys` shared array was declared and read but never
  // written — an uninitialized-shared-memory read whose result was itself
  // unused. Removing it also frees n_warp_per_block * warp_size * sizeof(Key)
  // bytes of shared memory per block.
  __shared__ Index batch_row_ids[n_warp_per_block][warp_size];
  __shared__ Key batch_missing_keys[n_warp_per_block][warp_size];
  __shared__ uint32_t batch_missing_indices[n_warp_per_block][warp_size];
  __shared__ uint32_t batch_n_missing[n_warp_per_block];
  for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
       batch_start += global_n_warp * warp_size) {
    const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
    if (lane_id == 0) { batch_n_missing[warp_id] = 0; }
    // NOTE(review): warps of the same block can exit this loop after
    // different trip counts (batch_start depends on global_warp_id), in which
    // case these block-wide barriers would be reached a different number of
    // times per warp. All shared-memory traffic here is per-warp, so a
    // warp-scoped sync would suffice — confirm intent before changing.
    __syncthreads();
    const uint32_t key_offset = batch_start + lane_id;
    if (key_offset < n_keys) {
      const Key key = keys[batch_start + lane_id];
      const uint64_t hash = FullCacheHash()(key);
      Index row;
      GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
      batch_row_ids[warp_id][lane_id] = row;
      if (row == 0) {  // row 0 means the key is absent from the table
        const uint32_t batch_missing_idx = atomicAdd(batch_n_missing + warp_id, 1);
        batch_missing_keys[warp_id][batch_missing_idx] = key;
        batch_missing_indices[warp_id][batch_missing_idx] = key_offset;
      }
    }
    __syncthreads();
    const uint32_t batch_n_missing_t = batch_n_missing[warp_id];
    if (lane_id == 0) {
      // Reserve a contiguous range in the global missing list; the shared
      // counter slot is reused to remember the range's start offset.
      const uint32_t old_n_missing =
          cuda::atomic::Add(n_missing, static_cast<uint32_t>(batch_n_missing_t));
      batch_n_missing[warp_id] = old_n_missing;
    }
    __syncthreads();
    if (lane_id < batch_n_missing_t) {
      missing_keys[batch_n_missing[warp_id] + lane_id] = batch_missing_keys[warp_id][lane_id];
      missing_indices[batch_n_missing[warp_id] + lane_id] =
          batch_missing_indices[warp_id][lane_id];
    }
    // The whole warp copies each present row, lanes striding over columns.
    for (uint32_t i = 0; i < batch_n_key; ++i) {
      const Index row = batch_row_ids[warp_id][i];
      if (row == 0) { continue; }
      for (uint32_t col = lane_id; col < value_length; col += warp_size) {
        values[(batch_start + i) * value_length + col] =
            cache_values[(row - 1) * value_length + col];
      }
    }
    __syncthreads();
  }
}
// Vectorized element group: the alignas guarantees a Pack load/store can be
// compiled to a single wide (sizeof(T) * pack_size byte) memory transaction.
template<typename T, size_t pack_size>
struct alignas(sizeof(T) * pack_size) Pack {
  T elem[pack_size];  // accessed element-wise in the update kernels
};
// Mask-producing variant of the fused lookup: instead of compacting missing
// keys it writes mask[k] = 1 when key k is present (0 otherwise) and gathers
// present rows with Pack-wide (pack_size elements) copies. Requires
// value_length % pack_size == 0 and suitably aligned values/cache_values.
template<typename Key, typename Elem, typename Index, uint32_t block_size, uint32_t pack_size>
__global__ void EncodeLookupMaskKernel(uint32_t value_length,
                                       const Elem* __restrict__ cache_values,
                                       uint32_t values_elem_cnt, const Key* __restrict__ keys,
                                       const Index* __restrict__ context,
                                       Elem* __restrict__ values, uint8_t* __restrict__ mask,
                                       const size_t capacity, Key* __restrict__ table_keys,
                                       Index* __restrict__ table_indices) {
  const uint32_t packed_cols = value_length / pack_size;
  auto* packed_values = reinterpret_cast<Pack<Elem, pack_size>*>(values);
  const auto* packed_cache_values = reinterpret_cast<const Pack<Elem, pack_size>*>(cache_values);
  constexpr uint32_t warp_size = 32;
  constexpr uint32_t n_warp_per_block = block_size / warp_size;
  const uint32_t warp_id = threadIdx.x / warp_size;
  const uint32_t lane_id = threadIdx.x % warp_size;
  const uint32_t global_warp_id = blockIdx.x * n_warp_per_block + warp_id;
  const uint32_t global_n_warp = gridDim.x * n_warp_per_block;
  const uint32_t n_keys = values_elem_cnt / value_length;
  // Fix: the former `batch_keys` shared array was declared and read but never
  // written (uninitialized-shared-memory read of an unused value); removed.
  __shared__ Index batch_row_ids[n_warp_per_block][warp_size];
  for (uint32_t batch_start = global_warp_id * warp_size; batch_start < n_keys;
       batch_start += global_n_warp * warp_size) {
    const uint32_t batch_n_key = min(n_keys - batch_start, warp_size);
    const uint32_t key_offset = batch_start + lane_id;
    if (key_offset < n_keys) {
      const Key key = keys[batch_start + lane_id];
      const uint64_t hash = FullCacheHash()(key);
      Index row;
      GetOne<Key, Index>(capacity, table_keys, table_indices, key, hash, &row);
      batch_row_ids[warp_id][lane_id] = row;
      mask[key_offset] = row > 0;  // 1 = hit, 0 = miss
    }
    // NOTE(review): warps of the same block can leave this loop after
    // different trip counts, making these block-wide barriers divergent; the
    // shared-memory exchange is per-warp, so a warp-scoped sync would
    // suffice — confirm intent before changing.
    __syncthreads();
    for (uint32_t i = 0; i < batch_n_key; ++i) {
      const Index row = batch_row_ids[warp_id][i];
      if (row == 0) { continue; }
#pragma unroll 4
      for (uint32_t col = lane_id; col < packed_cols; col += warp_size) {
        packed_values[(batch_start + i) * packed_cols + col] =
            packed_cache_values[(row - 1) * packed_cols + col];
      }
    }
    __syncthreads();
  }
}
// Scatters packed value rows into the cache. One thread iteration per Pack of
// pack_size elements; entries whose context is 0 (key absent) are skipped.
template<typename Elem, typename Index, size_t pack_size>
__global__ void UpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
                             const Index* context, const Elem* values) {
  const int n_packed = values_elem_cnt / pack_size;
  const uint32_t packs_per_row = value_length / pack_size;
  auto* dst = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
  const auto* src = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
  CUDA_1D_KERNEL_LOOP(i, n_packed) {
    const uint64_t entry = i / packs_per_row;
    const uint64_t ctx = context[entry];
    if (ctx == 0) { continue; }
    // context stores row id + 1; column is the remainder within the row.
    const uint64_t col = i - entry * packs_per_row;
    dst[(ctx - 1) * packs_per_row + col] = src[i];
  }
}
// Fused SGD-style update for float caches: writes cache_row = values + alpha *
// update, where alpha = -(*lr) * scale and `update` is half precision.
// Enabled only for Elem == float (the non-float overload below traps).
// One thread iteration per Pack; entries with context == 0 are skipped.
template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* __restrict__ cache_values,
                      uint32_t values_elem_cnt, const Index* __restrict__ context,
                      const Elem* __restrict__ values, const half* __restrict__ update,
                      const float* __restrict__ lr, float scale) {
  const int packed_values_elem_cnt = values_elem_cnt / pack_size;
  const uint32_t packed_elem_cnt = value_length / pack_size;  // packs per row
  auto* packed_cache_values = reinterpret_cast<Pack<Elem, pack_size>*>(cache_values);
  auto* packed_values = reinterpret_cast<const Pack<Elem, pack_size>*>(values);
  auto* packed_update = reinterpret_cast<const Pack<half, pack_size>*>(update);
  // lr is read from device memory so the host never has to synchronize on it.
  const float alpha = -*lr * scale;
  CUDA_1D_KERNEL_LOOP(i, packed_values_elem_cnt) {
    const uint64_t key_id = i / packed_elem_cnt;
    const uint64_t ctx = context[key_id];
    if (ctx == 0) { continue; }  // key absent: nothing to update
    const uint64_t row_id = ctx - 1;  // context stores row id + 1
    const uint64_t col_id = i - key_id * packed_elem_cnt;
    Pack<Elem, pack_size> m = packed_values[i];
    Pack<half, pack_size> u = packed_update[i];
    // Apply the scaled half-precision gradient in float, then store the
    // whole pack with a single wide write.
    for (size_t j = 0; j < pack_size; ++j) {
      m.elem[j] += static_cast<Elem>(u.elem[j]) * alpha;
    }
    packed_cache_values[row_id * packed_elem_cnt + col_id] = m;
  }
}
// Non-float instantiation exists only to satisfy compile-time dispatch; the
// host side never launches it (CacheImpl::FusedHalfUpdatePut hits
// UNIMPLEMENTED() first for non-float Elem). "s_trap 0" is an AMD GCN trap
// instruction so any accidental launch aborts immediately.
// NOTE(review): this asm is AMD-specific — confirm the file is only compiled
// for ROCm targets.
template<typename Elem, typename Index, size_t pack_size>
__global__ typename std::enable_if<!std::is_same<Elem, float>::value, void>::type
FusedHalfUpdateKernel(uint32_t value_length, Elem* cache_values, uint32_t values_elem_cnt,
                      const Index* context, const Elem* values, const half* update,
                      const float* lr, float scale) {
  asm volatile("s_trap 0;");
}
// Gathers the value rows of the *n_dumped dumped entries into `values`.
// context[k] holds row id + 1 for dumped entry k (as written by
// OrdinalEncoder::Dump); *n_dumped is read on-device so no host sync needed.
template<typename Key, typename Elem, typename Index>
__global__ void DumpValueKernel(uint32_t value_length, const uint32_t* n_dumped,
                                const Index* context, const Elem* cache_values, Elem* values) {
  CUDA_1D_KERNEL_LOOP(i, *n_dumped * value_length) {
    const uint64_t entry = i / value_length;
    const uint64_t row = context[entry] - 1;
    const uint64_t col = i - entry * value_length;
    values[i] = cache_values[row * value_length + col];
  }
}
// GPU-resident open-addressing table mapping keys to dense ordinal ids.
// Slots: table_capacity_ = capacity / load_factor; occupancy is tracked in
// device memory (table_size_) and mirrored to pinned host memory
// (table_size_host_) after insert-mode encodes for the capacity check.
template<typename Key, typename Index>
class OrdinalEncoder {
 public:
  OF_DISALLOW_COPY_AND_MOVE(OrdinalEncoder);
  explicit OrdinalEncoder(uint64_t capacity, float load_factor)
      : capacity_(capacity), table_capacity_(capacity / load_factor) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    OF_CUDA_CHECK(hipMalloc(&table_size_, sizeof(Index)));
    // Pinned host mirror so the size can be copied back asynchronously.
    OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&table_size_host_), sizeof(Index)));
    OF_CUDA_CHECK(hipMalloc(&table_keys_, table_capacity_ * sizeof(Key)));
    OF_CUDA_CHECK(hipMalloc(&table_indices_, table_capacity_ * sizeof(Index)));
    Clear();  // zero-fill: 0 marks empty slots
  }
  ~OrdinalEncoder() {
    // Frees must run on the device that allocated the buffers.
    CudaCurrentDeviceGuard guard(device_index_);
    OF_CUDA_CHECK(hipFree(table_size_));
    OF_CUDA_CHECK(hipHostFree(table_size_host_));
    OF_CUDA_CHECK(hipFree(table_keys_));
    OF_CUDA_CHECK(hipFree(table_indices_));
  }
  // Writes each key's ordinal id into context[i]. insert=true assigns fresh
  // ids to unseen keys and synchronizes the stream to verify the table has
  // not outgrown `capacity_` (so the insert path is comparatively expensive);
  // insert=false is lookup-only and fully asynchronous (missing keys get 0).
  template<bool insert>
  void Encode(ep::Stream* stream, uint32_t num_keys, const Key* keys, Index* context) {
    if (insert) {
      RUN_CUDA_KERNEL((OrdinalEncodeKernel<Key, Index>), stream, num_keys, table_capacity_,
                      table_keys_, table_indices_, table_size_, num_keys, keys, context);
      OF_CUDA_CHECK(hipMemcpyAsync(table_size_host_, table_size_, sizeof(Index), hipMemcpyDefault,
                                   stream->As<ep::CudaStream>()->cuda_stream()));
      CHECK_JUST(stream->Sync());
      CHECK_LT(*table_size_host_, capacity_)
          << "The number of key is larger than cache size, please enlarge cache_memory_budget. ";
    } else {
      RUN_CUDA_KERNEL((OrdinalEncodeLookupKernel<Key, Index>), stream, num_keys, table_capacity_,
                      table_keys_, table_indices_, num_keys, keys, context);
    }
  }
  // Compacts occupied slots in [start_key_index, end_key_index) into
  // keys/context; *n_dumped is zeroed here and counted on-device.
  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
            uint32_t* n_dumped, Key* keys, Index* context) {
    OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    RUN_CUDA_KERNEL((OrdinalEncodeDumpKernel<Key, Index>), stream,
                    end_key_index - start_key_index, table_keys_, table_indices_, start_key_index,
                    end_key_index, n_dumped, keys, context);
  }
  // Empties the table. Synchronous (default-stream hipMemset).
  void Clear() {
    OF_CUDA_CHECK(hipMemset(table_size_, 0, sizeof(Index)));
    OF_CUDA_CHECK(hipMemset(table_keys_, 0, table_capacity_ * sizeof(Key)));
    OF_CUDA_CHECK(hipMemset(table_indices_, 0, table_capacity_ * sizeof(Index)));
  }
  uint64_t TableCapacity() const { return table_capacity_; }
  Key* table_keys() const { return table_keys_; }
  Index* table_indices() const { return table_indices_; }

 private:
  int device_index_{};
  Key* table_keys_;            // device array, table_capacity_ slots
  Index* table_indices_;       // device array, parallel to table_keys_
  uint64_t capacity_;          // max number of distinct keys allowed
  uint64_t table_capacity_;    // slot count (capacity_ / load_factor)
  Index* table_size_{};        // device counter of assigned ids
  Index* table_size_host_{};   // pinned host mirror of table_size_
};
// "Full" cache policy: every inserted key is kept (no eviction). Values live
// in one flat [capacity x value_size] buffer, in device memory or pinned host
// memory depending on options. Key->row mapping is owned by OrdinalEncoder.
template<typename Key, typename Elem, typename Index, size_t pack_size>
class CacheImpl : public Cache {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CacheImpl);
  explicit CacheImpl(const CacheOptions& options)
      : encoder_(options.capacity, options.load_factor),
        device_index_(-1),
        options_(options),
        max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    const uint64_t values_size = options.capacity * options.value_size;
    if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
      OF_CUDA_CHECK(hipMalloc(&values_, values_size));
    } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
      if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
        OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&values_), values_size));
      } else {
        // NUMA-aware pinned allocation keeps host-side values on the socket
        // closest to the GPU.
        OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                              reinterpret_cast<void**>(&values_), values_size));
      }
    } else {
      UNIMPLEMENTED();
    }
    num_elem_per_value_ = options_.value_size / sizeof(Elem);
  }
  ~CacheImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    // Free with the API matching the allocation path chosen in the ctor.
    if (options_.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
      OF_CUDA_CHECK(hipFree(values_));
    } else if (options_.value_memory_kind == CacheOptions::MemoryKind::kHost) {
      OF_CUDA_CHECK(hipHostFree(values_));
    } else {
      UNIMPLEMENTED();
    }
    // encoding_buffer_ exists only after a successful ReserveQueryLength.
    if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
  }
  uint64_t Capacity() const override { return options_.capacity; }
  // Dump iteration space is the hash-table slot count, not the key capacity.
  uint64_t DumpCapacity() const override { return encoder_.TableCapacity(); }
  uint32_t KeySize() const override { return options_.key_size; }
  uint32_t ValueSize() const override { return options_.value_size; }
  DataType ValueType() const override { return options_.value_type; }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  // Grows (never shrinks) the per-query scratch buffer used for encoded
  // contexts. Sized with sizeof(uint64_t), which is >= sizeof(Index) for both
  // index widths used by the dispatchers.
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (max_query_length_ > 0) { OF_CUDA_CHECK(hipFree(encoding_buffer_)); }
    OF_CUDA_CHECK(hipMalloc(&encoding_buffer_, query_length * sizeof(uint64_t)));
    max_query_length_ = query_length;
  }
  CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kFull; }
  // Reports which keys are absent without returning any values.
  void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
            void* missing_keys, uint32_t* missing_indices) override;
  // Gathers values for present keys and compacts the missing ones.
  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,
           uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) override;
  // Mask variant: mask[i] = 1 when key i was found.
  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,
           uint8_t* mask) override;
  // Inserts/overwrites key->value pairs; a full cache never evicts.
  void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
           uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override;
  // Put fused with a half-precision gradient step (float caches only).
  void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys,
                          const void* values, const void* update, const float* lr, float scale,
                          uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override;
  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
            uint32_t* n_dumped, void* keys, void* values) override;
  void Clear() override;

 private:
  OrdinalEncoder<Key, Index> encoder_;  // key -> dense row-id table
  int device_index_;
  uint32_t num_elem_per_value_{};       // value_size / sizeof(Elem)
  Elem* values_;                        // flat value storage (device or pinned host)
  Index* encoding_buffer_{};            // per-query scratch for encoded contexts
  CacheOptions options_;
  uint32_t max_query_length_;
};
// Membership probe: zeroes *n_missing, encodes the keys in lookup-only mode,
// and runs LookupKernel with return_value=false so only the missing
// keys/indices are produced (values pointer is nullptr and never written).
// Requires n_keys <= max_query_length_ (encoding_buffer_ capacity).
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Test(ep::Stream* stream, uint32_t n_keys,
                                                  const void* keys, uint32_t* n_missing,
                                                  void* missing_keys, uint32_t* missing_indices) {
  OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
                               stream->As<ep::CudaStream>()->cuda_stream()));
  // The memset above still runs for n_keys == 0 so callers read a valid 0.
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  encoder_.template Encode<false>(stream, n_keys, static_cast<const Key*>(keys),
                                  encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  RUN_CUDA_KERNEL((LookupKernel<Key, Elem, Index, false>), stream, values_elem_cnt,
                  num_elem_per_value_, values_, values_elem_cnt,
                  static_cast<const Key*>(keys), encoding_buffer_, nullptr, n_missing,
                  static_cast<Key*>(missing_keys), missing_indices);
}
// Lookup with miss reporting: zeroes *n_missing, then launches the fused
// EncodeLookupKernel (hash lookup + gather in one pass — no separate Encode
// step). Grid is sized so each 128-thread block's 4 warps cover 32 keys each.
// Requires n_keys <= max_query_length_.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, void* values,
                                                 uint32_t* n_missing, void* missing_keys,
                                                 uint32_t* missing_indices) {
  OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
                               stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  constexpr uint32_t block_size = 128;
  uint32_t grid_size = (n_keys + block_size - 1) / block_size;
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  EncodeLookupKernel<Key, Elem, Index, block_size>
      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
          num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
          encoding_buffer_, static_cast<Elem*>(values), n_missing,
          static_cast<Key*>(missing_keys), missing_indices, encoder_.TableCapacity(),
          encoder_.table_keys(), encoder_.table_indices());
}
// Mask-variant lookup: writes mask[i] = 1 for found keys and gathers their
// rows with Pack-wide copies; missing keys simply leave mask[i] = 0 and their
// value rows untouched. No miss compaction, so nothing needs pre-zeroing.
// Requires n_keys <= max_query_length_.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Get(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, void* values, uint8_t* mask) {
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  constexpr uint32_t block_size = 128;
  uint32_t grid_size = (n_keys + block_size - 1) / block_size;
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  EncodeLookupMaskKernel<Key, Elem, Index, block_size, pack_size>
      <<<grid_size, block_size, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
          num_elem_per_value_, values_, values_elem_cnt, static_cast<const Key*>(keys),
          encoding_buffer_, static_cast<Elem*>(values), mask, encoder_.TableCapacity(),
          encoder_.table_keys(), encoder_.table_indices());
}
// Inserts/overwrites key->value pairs. A full cache never evicts, so
// *n_evicted is simply zeroed and evicted_keys/evicted_values stay unused.
// Requires n_keys <= max_query_length_.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Put(ep::Stream* stream, uint32_t n_keys,
                                                 const void* keys, const void* values,
                                                 uint32_t* n_evicted, void* evicted_keys,
                                                 void* evicted_values) {
  OF_CUDA_CHECK(hipMemsetAsync(n_evicted, 0, sizeof(uint32_t),
                               stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  // Insert-mode encode: assigns fresh ordinal ids to unseen keys and
  // synchronizes internally to validate the occupancy against capacity.
  encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys),
                                 encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  // Scatter the new rows into the cache, one Pack per thread iteration.
  RUN_CUDA_KERNEL((UpdateKernel<Elem, Index, pack_size>), stream, values_elem_cnt / pack_size,
                  num_elem_per_value_, values_, values_elem_cnt, encoding_buffer_,
                  static_cast<const Elem*>(values));
}
// Put fused with an SGD-style step: cache_row = values - (*lr) * scale *
// update, where `update` is half precision. Only implemented for float
// caches; other Elem types abort before any device work. Like Put, a full
// cache never evicts, so *n_evicted is zeroed.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::FusedHalfUpdatePut(
    ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values, const void* update,
    const float* lr, float scale, uint32_t* n_evicted, void* evicted_keys, void* evicted_values) {
  // Compile-time property checked at runtime; the non-float kernel overload
  // is a device trap and must never be launched.
  if (!std::is_same<Elem, float>::value) { UNIMPLEMENTED(); }
  OF_CUDA_CHECK(hipMemsetAsync(n_evicted, 0, sizeof(uint32_t),
                               stream->As<ep::CudaStream>()->cuda_stream()));
  if (n_keys == 0) { return; }
  CHECK_LE(n_keys, max_query_length_);
  encoder_.template Encode<true>(stream, n_keys, static_cast<const Key*>(keys),
                                 encoding_buffer_);
  const uint32_t values_elem_cnt = n_keys * num_elem_per_value_;
  RUN_CUDA_KERNEL((FusedHalfUpdateKernel<Elem, Index, pack_size>), stream,
                  values_elem_cnt / pack_size, num_elem_per_value_, values_, values_elem_cnt,
                  encoding_buffer_, static_cast<const Elem*>(values),
                  static_cast<const half*>(update), lr, scale);
}
// Dumps occupied table slots in [start_key_index, end_key_index): the encoder
// compacts keys and writes their contexts into encoding_buffer_, then
// DumpValueKernel gathers the matching value rows. NOTE(review): reuses
// encoding_buffer_ as dump scratch — assumes ReserveQueryLength covered the
// largest dump range; confirm against callers.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Dump(ep::Stream* stream, uint64_t start_key_index,
                                                  uint64_t end_key_index, uint32_t* n_dumped,
                                                  void* keys, void* values) {
  encoder_.Dump(stream, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
                encoding_buffer_);
  // Grid covers the worst case (whole range dumped); the kernel reads the
  // actual *n_dumped on-device, so no host synchronization is needed.
  RUN_CUDA_KERNEL((DumpValueKernel<Key, Elem, Index>), stream,
                  num_elem_per_value_ * (end_key_index - start_key_index), num_elem_per_value_,
                  n_dumped, encoding_buffer_, values_, static_cast<Elem*>(values));
}
// Resets the cache by emptying the encoder's key->row table. Stale rows in
// values_ become unreachable afterwards, so the value buffer itself is left
// untouched.
template<typename Key, typename Elem, typename Index, size_t pack_size>
void CacheImpl<Key, Elem, Index, pack_size>::Clear() {
  encoder_.Clear();
}
// Selects the element type (and, for float, the Pack width) from the value
// options, then instantiates the matching CacheImpl.
template<typename Key, typename Index>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
  if (options.value_type == DataType::kFloat) {
    const size_t value_elem_cnt = options.value_size / sizeof(float);
    // Use packed loads only while each warp still gets more than half a warp
    // (16) of packs per row to work on.
    const size_t half_warp = 16;
    const bool use_pack4 = (value_elem_cnt % 4 == 0) && (value_elem_cnt / 4 > half_warp);
    const bool use_pack2 = (value_elem_cnt % 2 == 0) && (value_elem_cnt / 2 > half_warp);
    if (use_pack4) {
      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 4>(options));
    }
    if (use_pack2) {
      return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 2>(options));
    }
    return std::unique_ptr<Cache>(new CacheImpl<Key, float, Index, 1>(options));
  }
  // Non-float values: treat rows as opaque bytes and pick the widest integer
  // element that divides value_size evenly.
  const auto value_size = options.value_size;
  if (value_size % sizeof(ulonglong2) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, ulonglong2, Index, 1>(options));
  }
  if (value_size % sizeof(uint64_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint64_t, Index, 1>(options));
  }
  if (value_size % sizeof(uint32_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint32_t, Index, 1>(options));
  }
  if (value_size % sizeof(uint16_t) == 0) {
    return std::unique_ptr<Cache>(new CacheImpl<Key, uint16_t, Index, 1>(options));
  }
  return std::unique_ptr<Cache>(new CacheImpl<Key, uint8_t, Index, 1>(options));
}
// Dispatches on the configured key byte width (Key32 or Key64).
template<typename Index>
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
  const auto key_size = options.key_size;
  if (key_size == sizeof(Key32)) { return DispatchValueType<Key32, Index>(options); }
  if (key_size == sizeof(Key64)) { return DispatchValueType<Key64, Index>(options); }
  UNIMPLEMENTED();
  return nullptr;
}
// Chooses the table index width: once the slot count (capacity / load_factor)
// can reach 2^31, 32-bit indices no longer suffice.
std::unique_ptr<Cache> DispatchIndexType(const CacheOptions& options) {
  const int64_t table_capacity = static_cast<double>(options.capacity) / options.load_factor;
  const bool needs_wide_index = table_capacity >= (1ULL << 31ULL);
  return needs_wide_index ? DispatchKeyType<uint64_t>(options)
                          : DispatchKeyType<uint32_t>(options);
}
}
// namespace
// Factory entry point for the kFull cache policy: picks index, key, element
// and pack types from the options via the dispatch chain above.
std::unique_ptr<Cache> NewFullCache(const CacheOptions& options) {
  return DispatchIndexType(options);
}
}
// namespace embedding
}
// namespace oneflow
\ No newline at end of file
oneflow/core/embedding/hash_functions.hip.h
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#define ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
#include <stdint.h>
#include "oneflow/core/common/data_type.h"
namespace
oneflow
{
namespace
embedding
{
namespace
{
// From https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h
// xxHash64 magic primes (verbatim from the xxHash reference implementation).
static const uint64_t PRIME64_1 = 0x9E3779B185EBCA87ULL;
// 0b1001111000110111011110011011000110000101111010111100101010000111
static const uint64_t PRIME64_2 = 0xC2B2AE3D27D4EB4FULL;
// 0b1100001010110010101011100011110100100111110101001110101101001111
static const uint64_t PRIME64_3 = 0x165667B19E3779F9ULL;
// 0b0001011001010110011001111011000110011110001101110111100111111001
static const uint64_t PRIME64_4 = 0x85EBCA77C2B2AE63ULL;
// 0b1000010111101011110010100111011111000010101100101010111001100011
static const uint64_t PRIME64_5 = 0x27D4EB2F165667C5ULL;
// 0b0010011111010100111010110010111100010110010101100110011111000101
// 64-bit rotate-left used by the xxHash mixing steps.
#define XXH_rotl64(x, r) (((x) << (r)) | ((x) >> (64 - (r))))
// One xxHash64 accumulator round: folds `input` into `acc`
// (multiply-by-prime, rotate, multiply).
OF_DEVICE_FUNC uint64_t XXH64_round(uint64_t acc, uint64_t input) {
  acc += input * PRIME64_2;
  acc = XXH_rotl64(acc, 31);
  acc *= PRIME64_1;
  return acc;
}
// xxHash64 specialized for a single 8-byte input: one round plus the
// reference finalization/avalanche sequence, presumably equivalent to
// XXH64(&v, sizeof(v), seed) — confirm against the upstream implementation.
OF_DEVICE_FUNC uint64_t xxh64_uint64(uint64_t v, uint64_t seed) {
  uint64_t acc = seed + PRIME64_5;
  acc += sizeof(uint64_t);  // total input length in bytes
  acc = acc ^ XXH64_round(0, v);
  acc = XXH_rotl64(acc, 27) * PRIME64_1;
  acc = acc + PRIME64_4;
  // Final avalanche: xor-shift/multiply steps spread entropy over all bits.
  acc ^= (acc >> 33);
  acc = acc * PRIME64_2;
  acc = acc ^ (acc >> 29);
  acc = acc * PRIME64_3;
  acc = acc ^ (acc >> 32);
  return acc;
}
// Distinct seeds give each use site an independent hash family over the same
// key space (e.g. sharding must not correlate with cache slot placement).
static const size_t kShardingHashSeed = 1;
static const size_t kLocalUniqueHashSeed = 2;
static const size_t kGlobalUniqueHashSeed = 3;
static const size_t kFullCacheHashSeed = 4;
static const size_t kLruCacheHashSeed = 5;
}
// namespace
// Hash functor used for shard assignment. Signed overloads reinterpret the
// value as its unsigned bit pattern before hashing, so e.g. int64_t(-1) and
// uint64_t(~0) hash identically.
struct ShardingHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kShardingHashSeed); }
  OF_DEVICE_FUNC size_t operator()(uint32_t v) { return xxh64_uint64(v, kShardingHashSeed); }
  OF_DEVICE_FUNC size_t operator()(int32_t v) {
    return xxh64_uint64(static_cast<uint32_t>(v), kShardingHashSeed);
  }
  OF_DEVICE_FUNC size_t operator()(int64_t v) {
    return xxh64_uint64(static_cast<uint64_t>(v), kShardingHashSeed);
  }
};
// Seed-differentiated hash functors over uint64 keys; each caller gets an
// independent hash family (see the seed constants above).
struct LocalUniqueHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLocalUniqueHashSeed); }
};
struct GlobalUniqueHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kGlobalUniqueHashSeed); }
};
struct FullCacheHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kFullCacheHashSeed); }
};
struct LruCacheHash {
  OF_DEVICE_FUNC size_t operator()(uint64_t v) { return xxh64_uint64(v, kLruCacheHashSeed); }
};
}
// namespace embedding
}
// namespace oneflow
#endif // ONEFLOW_CORE_EMBEDDING_HASH_FUNCTION_HIP_H_
\ No newline at end of file
oneflow/core/embedding/key_value_store_test.cpp
View file @
a715222c
...
...
@@ -263,6 +263,242 @@ TEST(MockKeyValueStore, Mock) {
#endif // WITH_CUDA
#ifdef WITH_ROCM
// Creates a fresh temporary directory under $TMPDIR (or /tmp) and returns its
// path. Aborts via PCHECK when mkdtemp fails.
std::string CreateTempDirectory() {
  const char* tmp_env = getenv("TMPDIR");
  const char* tmp_dir = tmp_env == nullptr ? "/tmp" : tmp_env;
  std::string tpl = std::string(tmp_dir) + "/test_kv_XXXXXX";
  // mkdtemp rewrites the XXXXXX suffix in place. Fix: writing through the
  // pointer returned by c_str() (the previous const_cast) is undefined
  // behavior; &tpl[0] is a writable, NUL-terminated buffer since C++11.
  char* path = mkdtemp(&tpl[0]);
  PCHECK(path != nullptr);
  return std::string(path);
}
// True when the HIP runtime reports at least one usable device; used to skip
// GPU tests on machines without one.
bool HasCudaDevice() {
  int device_count = 0;
  if (hipGetDeviceCount(&device_count) != hipSuccess) { return false; }
  return device_count > 0;
}
// Drives a KeyValueStore through Put/Get/snapshot round trips:
//  1) every key is missing on first Get, then inserted batch by batch;
//  2) after the "final" snapshot every value reads back intact;
//  3) restoring the "init" snapshot makes all keys missing again;
//  4) restoring "final" brings every value back.
// Keys are 1..num_embeddings and each float vector is filled with its own key
// so content checks are self-describing. The store must have
// ReserveQueryLength(>= 128) (the batch size) applied beforehand.
void TestKeyValueStore(KeyValueStore* store, size_t num_embeddings, size_t test_embeddings,
                       size_t embedding_vec_size) {
  auto device = Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, 0);
  ep::Stream* stream = device->CreateStream();
  store->SaveSnapshot("init");  // snapshot of the empty store
  uint64_t* keys = nullptr;
  float* values = nullptr;
  float* values1 = nullptr;
  uint64_t* keys_host = nullptr;
  float* values_host = nullptr;
  uint64_t* context = nullptr;
  uint32_t* n_missing = nullptr;
  uint32_t* host_n_missing = nullptr;
  uint64_t* missing_keys = nullptr;  // allocated but not passed to Get below
  uint32_t* missing_indices = nullptr;
  size_t keys_size = sizeof(uint64_t) * num_embeddings;
  size_t values_size = sizeof(float) * embedding_vec_size * num_embeddings;
  size_t context_size = sizeof(uint64_t) * num_embeddings;
  const size_t batch_size = 128;
  OF_CUDA_CHECK(hipMalloc(&keys, keys_size));
  OF_CUDA_CHECK(hipMalloc(&values, values_size));
  OF_CUDA_CHECK(hipMalloc(&values1, values_size));
  OF_CUDA_CHECK(hipMalloc(&context, context_size));
  OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&keys_host), keys_size));
  OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&values_host), values_size));
  OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&host_n_missing), sizeof(uint32_t)));
  OF_CUDA_CHECK(hipMalloc(&missing_keys, batch_size * sizeof(uint64_t)));
  OF_CUDA_CHECK(hipMalloc(&missing_indices, batch_size * sizeof(uint32_t)));
  OF_CUDA_CHECK(hipMalloc(&n_missing, sizeof(uint32_t)));
  for (size_t i = 0; i < num_embeddings; ++i) {
    uint64_t key = i + 1;
    keys_host[i] = key;
    for (size_t j = 0; j < embedding_vec_size; j++) {
      values_host[i * embedding_vec_size + j] = key;
    }
  }
  OF_CUDA_CHECK(hipMemcpy(keys, keys_host, keys_size, hipMemcpyDefault));
  OF_CUDA_CHECK(hipMemcpy(values, values_host, values_size, hipMemcpyDefault));
  store->Put(stream, 0, keys, values);  // zero-key Put must be a no-op
  OF_CUDA_CHECK(hipDeviceSynchronize());
  OF_CUDA_CHECK(hipGetLastError());
  // Phase 1: everything misses on first Get; insert each batch afterwards.
  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {
    const size_t num_keys = std::min(batch_size, test_embeddings - offset);
    store->Get(stream, num_keys, keys + offset, values1 + offset * embedding_vec_size, n_missing,
               missing_indices);
    OF_CUDA_CHECK(hipMemcpy(host_n_missing, n_missing, sizeof(uint32_t), hipMemcpyDefault));
    OF_CUDA_CHECK(hipDeviceSynchronize());
    ASSERT_EQ(*host_n_missing, num_keys);
    store->Put(stream, num_keys, keys + offset, values + offset * embedding_vec_size);
  }
  OF_CUDA_CHECK(hipDeviceSynchronize());
  store->SaveSnapshot("final");  // snapshot of the fully-populated store
  // NOTE(review): hipMemset on pinned HOST memory relies on unified
  // addressing; a plain memset would be the conventional choice — confirm.
  OF_CUDA_CHECK(hipMemset(values_host, 0, values_size));
  OF_CUDA_CHECK(hipMemset(values, 0, values_size));
  // Phase 2: every key must now hit and round-trip its value.
  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {
    const size_t num_keys = std::min(batch_size, test_embeddings - offset);
    store->Get(stream, num_keys, keys + offset, values + offset * embedding_vec_size, n_missing,
               missing_indices);
    OF_CUDA_CHECK(hipMemcpy(host_n_missing, n_missing, sizeof(uint32_t), hipMemcpyDefault));
    OF_CUDA_CHECK(hipDeviceSynchronize());
    ASSERT_EQ(*host_n_missing, 0);
  }
  OF_CUDA_CHECK(hipMemcpy(values_host, values, values_size, hipMemcpyDefault));
  OF_CUDA_CHECK(hipDeviceSynchronize());
  for (size_t i = 0; i < test_embeddings; ++i) {
    uint64_t key = keys_host[i];
    for (size_t j = 0; j < embedding_vec_size; j++) {
      ASSERT_EQ(values_host[i * embedding_vec_size + j], key);
    }
  }
  // Phase 3: restoring "init" must make all keys missing again.
  store->LoadSnapshot("init");
  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {
    const size_t num_keys = std::min(batch_size, test_embeddings - offset);
    store->Get(stream, num_keys, keys + offset, values1 + offset * embedding_vec_size, n_missing,
               missing_indices);
    OF_CUDA_CHECK(hipMemcpy(host_n_missing, n_missing, sizeof(uint32_t), hipMemcpyDefault));
    OF_CUDA_CHECK(hipDeviceSynchronize());
    ASSERT_EQ(*host_n_missing, num_keys);
  }
  // Phase 4: restoring "final" must bring every value back.
  store->LoadSnapshot("final");
  OF_CUDA_CHECK(hipMemset(values_host, 0, values_size));
  OF_CUDA_CHECK(hipMemset(values, 0, values_size));
  for (size_t offset = 0; offset < test_embeddings; offset += batch_size) {
    const size_t num_keys = std::min(batch_size, test_embeddings - offset);
    store->Get(stream, num_keys, keys + offset, values + offset * embedding_vec_size, n_missing,
               missing_indices);
    OF_CUDA_CHECK(hipMemcpy(host_n_missing, n_missing, sizeof(uint32_t), hipMemcpyDefault));
    OF_CUDA_CHECK(hipDeviceSynchronize());
    ASSERT_EQ(*host_n_missing, 0);
  }
  OF_CUDA_CHECK(hipMemcpy(values_host, values, values_size, hipMemcpyDefault));
  OF_CUDA_CHECK(hipDeviceSynchronize());
  for (size_t i = 0; i < test_embeddings; ++i) {
    uint64_t key = keys_host[i];
    for (size_t j = 0; j < embedding_vec_size; j++) {
      ASSERT_EQ(values_host[i * embedding_vec_size + j], key);
    }
  }
  OF_CUDA_CHECK(hipDeviceSynchronize());
  OF_CUDA_CHECK(hipGetLastError());
  OF_CUDA_CHECK(hipFree(keys));
  OF_CUDA_CHECK(hipFree(values));
  OF_CUDA_CHECK(hipFree(values1));
  // Fix: `context` was allocated above but never released, leaking one
  // sizeof(uint64_t) * num_embeddings device buffer per call.
  OF_CUDA_CHECK(hipFree(context));
  OF_CUDA_CHECK(hipHostFree(keys_host));
  OF_CUDA_CHECK(hipHostFree(values_host));
  OF_CUDA_CHECK(hipHostFree(host_n_missing));
  OF_CUDA_CHECK(hipFree(n_missing));
  OF_CUDA_CHECK(hipFree(missing_keys));
  OF_CUDA_CHECK(hipFree(missing_indices));
  CHECK_JUST(stream->Sync());
  device->DestroyStream(stream);
}
// End-to-end smoke test of the persistent-table store: 1024 keys with
// 128-float values, driven through TestKeyValueStore's Put/Get/snapshot
// round trips. Cleans up its temporary table directory afterwards.
TEST(PersistentTableKeyValueStore, PersistentTableKeyValueStore) {
  if (!HasCudaDevice()) { return; }  // skip silently on GPU-less machines
  Singleton<ep::DeviceManagerRegistry>::New();
  PersistentTableKeyValueStoreOptions options{};
  uint32_t value_length = 128;
  std::string path = CreateTempDirectory();
  options.table_options.path = path;
  options.table_options.value_size = value_length * sizeof(float);
  options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);
  // Physical block size for the table's on-disk I/O.
  options.table_options.physical_block_size = 512;
  std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(options);
  store->ReserveQueryLength(128);  // matches TestKeyValueStore's batch size
  TestKeyValueStore(store.get(), 1024, 1024, value_length);
  store.reset();  // close the table before deleting its directory
  PosixFile::RecursiveDelete(path);
  Singleton<ep::DeviceManagerRegistry>::Delete();
}
// TEST(CachedKeyValueStore, LRU) {
// if (!HasCudaDevice()) { return; }
// Singleton<ep::DeviceManagerRegistry>::New();
// PersistentTableKeyValueStoreOptions store_options{};
// std::string path = CreateTempDirectory();
// store_options.table_options.path = path;
// uint32_t value_length = 128;
// store_options.table_options.value_size = value_length * sizeof(float);
// store_options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);
// store_options.table_options.physical_block_size = 512;
// std::unique_ptr<KeyValueStore> store = NewPersistentTableKeyValueStore(store_options);
// CacheOptions cache_options{};
// cache_options.policy = CacheOptions::Policy::kLRU;
// cache_options.value_memory_kind = CacheOptions::MemoryKind::kDevice;
// cache_options.value_size = 512;
// cache_options.capacity = 512;
// cache_options.key_size = 8;
// std::unique_ptr<Cache> cache = NewCache(cache_options);
// std::unique_ptr<KeyValueStore> cached_store =
// NewCachedKeyValueStore(std::move(store), std::move(cache));
// cached_store->ReserveQueryLength(128);
// TestKeyValueStore(cached_store.get(), 1024, 1024, value_length);
// cached_store.reset();
// PosixFile::RecursiveDelete(path);
// Singleton<ep::DeviceManagerRegistry>::Delete();
// }
TEST(CachedKeyValueStore, Full) {
  // Exercises the persistent store wrapped by a full (host-memory) cache.
  if (!HasCudaDevice()) { return; }
  Singleton<ep::DeviceManagerRegistry>::New();
  const uint32_t value_length = 128;
  const std::string temp_dir = CreateTempDirectory();
  // Backing persistent table.
  PersistentTableKeyValueStoreOptions store_options{};
  store_options.table_options.path = temp_dir;
  store_options.table_options.value_size = value_length * sizeof(float);
  store_options.table_options.key_size = GetSizeOfDataType(DataType::kUInt64);
  store_options.table_options.physical_block_size = 512;
  auto store = NewPersistentTableKeyValueStore(store_options);
  // Full-policy cache holding values in pinned host memory.
  CacheOptions cache_options{};
  cache_options.policy = CacheOptions::Policy::kFull;
  cache_options.value_memory_kind = CacheOptions::MemoryKind::kHost;
  cache_options.value_size = 512;
  cache_options.capacity = 1024 * 2;
  cache_options.key_size = 8;
  auto cache = NewCache(cache_options);
  auto cached_store = NewCachedKeyValueStore(std::move(store), std::move(cache));
  cached_store->ReserveQueryLength(128);
  TestKeyValueStore(cached_store.get(), 1024, 1024, value_length);
  // Release before deleting the backing directory.
  cached_store.reset();
  PosixFile::RecursiveDelete(temp_dir);
  Singleton<ep::DeviceManagerRegistry>::Delete();
}
TEST(MockKeyValueStore, Mock) {
  // Exercises the in-memory mock key-value store.
  if (!HasCudaDevice()) { return; }
  Singleton<ep::DeviceManagerRegistry>::New();
  const uint32_t value_length = 128;
  const std::string temp_dir = CreateTempDirectory();
  MockKeyValueStoreOptions store_options{};
  store_options.value_size = value_length * sizeof(float);
  store_options.key_size = GetSizeOfDataType(DataType::kUInt64);
  auto store = NewMockKeyValueStore(store_options);
  store->ReserveQueryLength(128);
  TestKeyValueStore(store.get(), 1024, 1024, value_length);
  store.reset();
  PosixFile::RecursiveDelete(temp_dir);
  Singleton<ep::DeviceManagerRegistry>::Delete();
}
#endif // WITH_ROCM
}
// namespace
}
// namespace embedding
...
...
oneflow/core/embedding/lru_cache.cu
View file @
a715222c
...
...
@@ -20,9 +20,14 @@ limitations under the License.
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.cuh"
#include <new>
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda.h>
#endif
#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700))
#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700)) \
&& !(defined(__clang__) && defined(__CUDA__))
#include <cuda/std/semaphore>
#endif
...
...
@@ -32,7 +37,11 @@ namespace embedding {
namespace
{
#ifdef WITH_ROCM
constexpr
int
kWarpSize
=
64
;
#else
constexpr
int
kWarpSize
=
32
;
#endif
constexpr
int
kNumWarpPerBlock
=
4
;
constexpr
int
kBlockSize
=
kNumWarpPerBlock
*
kWarpSize
;
constexpr
uint32_t
kFullMask
=
0xFFFFFFFFU
;
...
...
@@ -69,11 +78,19 @@ class WarpMutexAtomicImpl {
;
}
__threadfence
();
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
__device__
void
Unlock
(
const
ThreadContext
&
thread_ctx
)
{
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
__threadfence
();
if
(
thread_ctx
.
lane_id
==
0
)
{
atomicExch
(
&
flag_
,
0
);
}
}
...
...
@@ -82,7 +99,8 @@ class WarpMutexAtomicImpl {
int32_t
flag_
;
};
#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700))
#if CUDA_VERSION >= 11000 && ((!defined(__CUDA_ARCH__)) || (__CUDA_ARCH__ >= 700)) \
&& !(defined(__clang__) && defined(__CUDA__))
class
WarpMutexSemaphoreImpl
{
public:
...
...
@@ -92,11 +110,19 @@ class WarpMutexSemaphoreImpl {
__device__
void
Lock
(
const
ThreadContext
&
thread_ctx
)
{
if
(
thread_ctx
.
lane_id
==
0
)
{
semaphore_
.
acquire
();
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
__device__
void
Unlock
(
const
ThreadContext
&
thread_ctx
)
{
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
if
(
thread_ctx
.
lane_id
==
0
)
{
semaphore_
.
release
();
}
}
...
...
@@ -118,19 +144,20 @@ struct LruCacheContext {
};
__global__
void
InitCacheSetMutex
(
uint32_t
n_set
,
void
*
mutex
)
{
#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
&& !(defined(__clang__) && defined(__CUDA__))
using
WarpMutex
=
WarpMutexSemaphoreImpl
;
#else
using
WarpMutex
=
WarpMutexAtomicImpl
;
#endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
#endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) &&
// defined(__CUDA__))
const
uint32_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
<
n_set
)
{
new
(
reinterpret_cast
<
WarpMutex
*>
(
mutex
)
+
idx
)
WarpMutex
;
}
}
template
<
typename
Key
,
typename
Elem
>
void
ClearLruCacheContext
(
LruCacheContext
<
Key
,
Elem
>*
ctx
)
{
OF_CUDA_CHECK
(
cuda
Memset
(
ctx
->
keys
,
0
,
ctx
->
n_set
*
kWarpSize
*
sizeof
(
Key
)));
OF_CUDA_CHECK
(
cuda
Memset
(
ctx
->
ages
,
0
,
ctx
->
n_set
*
kWarpSize
*
sizeof
(
uint8_t
)));
OF_CUDA_CHECK
(
GPU
(
Memset
)
(
ctx
->
keys
,
0
,
ctx
->
n_set
*
kWarpSize
*
sizeof
(
Key
)));
OF_CUDA_CHECK
(
GPU
(
Memset
)
(
ctx
->
ages
,
0
,
ctx
->
n_set
*
kWarpSize
*
sizeof
(
uint8_t
)));
InitCacheSetMutex
<<<
(
ctx
->
n_set
-
1
+
256
)
/
256
,
256
>>>
(
ctx
->
n_set
,
ctx
->
mutex
);
}
...
...
@@ -141,11 +168,13 @@ void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>
const
size_t
lines_size_per_set
=
kWarpSize
*
line_size
*
sizeof
(
Elem
);
const
size_t
ages_size_per_set
=
kWarpSize
*
sizeof
(
uint8_t
);
int
device
=
0
;
OF_CUDA_CHECK
(
cuda
GetDevice
(
&
device
));
OF_CUDA_CHECK
(
GPU
(
GetDevice
)
(
&
device
));
int
major
=
0
;
#ifdef WITH_CUDA
OF_CUDA_CHECK
(
cudaDeviceGetAttribute
(
&
major
,
cudaDevAttrComputeCapabilityMajor
,
device
));
#endif
size_t
mutex_size_per_set
=
0
;
#if CUDA_VERSION >= 11000
#if CUDA_VERSION >= 11000
&& !(defined(__clang__) && defined(__CUDA__))
if
(
major
>=
7
)
{
#if !defined(__CUDA_ARCH__)
mutex_size_per_set
=
sizeof
(
WarpMutexSemaphoreImpl
);
...
...
@@ -157,19 +186,23 @@ void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>
}
#else
mutex_size_per_set
=
sizeof
(
WarpMutexAtomicImpl
);
#endif // CUDA_VERSION >= 11000
#endif // CUDA_VERSION >= 11000
&& !(defined(__clang__) && defined(__CUDA__))
const
size_t
n_set
=
(
options
.
capacity
-
1
+
kWarpSize
)
/
kWarpSize
;
CHECK_GT
(
n_set
,
0
);
ctx
->
n_set
=
n_set
;
ctx
->
line_size
=
line_size
;
const
size_t
keys_size
=
n_set
*
keys_size_per_set
;
OF_CUDA_CHECK
(
cuda
Malloc
(
&
(
ctx
->
keys
),
keys_size
));
OF_CUDA_CHECK
(
GPU
(
Malloc
)
(
&
(
ctx
->
keys
),
keys_size
));
const
size_t
lines_size
=
n_set
*
lines_size_per_set
;
if
(
options
.
value_memory_kind
==
CacheOptions
::
MemoryKind
::
kDevice
)
{
OF_CUDA_CHECK
(
cuda
Malloc
(
&
(
ctx
->
lines
),
lines_size
));
OF_CUDA_CHECK
(
GPU
(
Malloc
)
(
&
(
ctx
->
lines
),
lines_size
));
}
else
if
(
options
.
value_memory_kind
==
CacheOptions
::
MemoryKind
::
kHost
)
{
if
(
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION"
,
false
))
{
#ifdef WITH_ROCM
OF_CUDA_CHECK
(
hipMallocHost
(
reinterpret_cast
<
void
**>
(
&
(
ctx
->
lines
)),
lines_size
));
#else
OF_CUDA_CHECK
(
cudaMallocHost
(
&
(
ctx
->
lines
),
lines_size
));
#endif
}
else
{
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device
,
reinterpret_cast
<
void
**>
(
&
ctx
->
lines
),
lines_size
));
...
...
@@ -179,45 +212,50 @@ void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>
}
ctx
->
value_memory_kind
=
options
.
value_memory_kind
;
const
size_t
ages_size
=
n_set
*
ages_size_per_set
;
OF_CUDA_CHECK
(
cuda
Malloc
(
&
(
ctx
->
ages
),
ages_size
));
OF_CUDA_CHECK
(
GPU
(
Malloc
)
(
&
(
ctx
->
ages
),
ages_size
));
const
size_t
mutex_size
=
n_set
*
mutex_size_per_set
;
OF_CUDA_CHECK
(
cuda
Malloc
(
&
(
ctx
->
mutex
),
mutex_size
));
OF_CUDA_CHECK
(
GPU
(
Malloc
)
(
&
(
ctx
->
mutex
),
mutex_size
));
ClearLruCacheContext
(
ctx
);
}
template
<
typename
Key
,
typename
Elem
>
void
DestroyLruCacheContext
(
LruCacheContext
<
Key
,
Elem
>*
ctx
)
{
OF_CUDA_CHECK
(
cuda
Free
(
ctx
->
keys
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
ctx
->
keys
));
if
(
ctx
->
value_memory_kind
==
CacheOptions
::
MemoryKind
::
kDevice
)
{
OF_CUDA_CHECK
(
cuda
Free
(
ctx
->
lines
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
ctx
->
lines
));
}
else
if
(
ctx
->
value_memory_kind
==
CacheOptions
::
MemoryKind
::
kHost
)
{
OF_CUDA_CHECK
(
cuda
FreeHost
(
ctx
->
lines
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
ctx
->
lines
));
}
else
{
UNIMPLEMENTED
();
}
OF_CUDA_CHECK
(
cuda
Free
(
ctx
->
ages
));
OF_CUDA_CHECK
(
cuda
Free
(
ctx
->
mutex
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
ctx
->
ages
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
ctx
->
mutex
));
}
template
<
typename
Key
,
typename
Elem
>
struct
SetContext
{
#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
#if CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
&& !(defined(__clang__) && defined(__CUDA__))
using
WarpMutex
=
WarpMutexSemaphoreImpl
;
#else
using
WarpMutex
=
WarpMutexAtomicImpl
;
#endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700
#endif // CUDA_VERSION >= 11000 && __CUDA_ARCH__ >= 700 && !(defined(__clang__) &&
// defined(__CUDA__))
__device__
SetContext
(
const
LruCacheContext
<
Key
,
Elem
>&
ctx
,
uint32_t
set_id
)
:
keys
(
ctx
.
keys
+
set_id
*
kWarpSize
),
mutex
(
reinterpret_cast
<
WarpMutex
*>
(
ctx
.
mutex
)
+
set_id
),
ages
(
ctx
.
ages
+
set_id
*
kWarpSize
),
lines
(
ctx
.
lines
+
set_id
*
kWarpSize
*
ctx
.
line_size
)
{}
lines
(
ctx
.
lines
+
static_cast
<
size_t
>
(
set_id
)
*
kWarpSize
*
ctx
.
line_size
)
{}
__device__
int
Lookup
(
const
ThreadContext
&
thread_ctx
,
Key
key
)
{
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
const
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
const
bool
lane_hit
=
(
lane_key
==
key
&&
lane_age
!=
0
);
#ifdef WITH_ROCM
const
unsigned
hit_mask
=
__ballot
(
lane_hit
);
#else
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_hit
);
#endif
if
(
hit_mask
!=
0
)
{
return
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
}
else
{
...
...
@@ -238,19 +276,35 @@ struct SetContext {
int
insert_way
=
-
1
;
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
#ifdef WITH_ROCM
const
unsigned
hit_mask
=
__ballot
(
lane_key
==
key
&&
lane_age
!=
0
);
#else
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_key
==
key
&&
lane_age
!=
0
);
#endif
if
(
hit_mask
!=
0
)
{
insert_way
=
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
#ifdef WITH_ROCM
const
int
insert_way_age
=
__shfl
(
lane_age
,
insert_way
);
#else
const
int
insert_way_age
=
__shfl_sync
(
kFullMask
,
lane_age
,
insert_way
);
#endif
if
(
lane_age
>
insert_way_age
)
{
lane_age
-=
1
;
}
else
if
(
thread_ctx
.
lane_id
==
insert_way
)
{
lane_age
=
kWarpSize
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
if
(
insert_way
==
-
1
)
{
const
unsigned
valid_mask
=
__ballot_sync
(
kFullMask
,
lane_age
!=
0
);
#ifdef WITH_ROCM
const
unsigned
valid_mask
=
__ballot
(
lane_age
!=
0
);
#else
const
unsigned
valid_mask
=
__ballot_sync
(
kFullMask
,
lane_age
!=
0
);
#endif
if
(
valid_mask
!=
kFullMask
)
{
insert_way
=
__popc
(
static_cast
<
int
>
(
valid_mask
));
if
(
lane_age
>
0
)
{
...
...
@@ -259,7 +313,11 @@ struct SetContext {
lane_age
=
kWarpSize
;
keys
[
insert_way
]
=
key
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
}
if
(
insert_way
!=
-
1
)
{
ages
[
thread_ctx
.
lane_id
]
=
lane_age
;
}
...
...
@@ -270,15 +328,28 @@ struct SetContext {
const
ThreadContext
&
thread_ctx
,
Key
key
,
int
*
way
,
Key
*
evicted_key
)
{
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
#ifdef WITH_ROCM
const
int
insert_way
=
__ffs
(
static_cast
<
int
>
(
__ballot
(
lane_age
==
1
)))
-
1
;
#else
const
int
insert_way
=
__ffs
(
__ballot_sync
(
kFullMask
,
lane_age
==
1
))
-
1
;
#endif
#ifdef WITH_ROCM
*
evicted_key
=
__shfl
(
lane_key
,
insert_way
);
#else
*
evicted_key
=
__shfl_sync
(
kFullMask
,
lane_key
,
insert_way
);
#endif
if
(
thread_ctx
.
lane_id
==
insert_way
)
{
keys
[
insert_way
]
=
key
;
lane_age
=
kWarpSize
;
}
else
if
(
lane_age
>
1
)
{
lane_age
-=
1
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
ages
[
thread_ctx
.
lane_id
]
=
lane_age
;
*
way
=
insert_way
;
}
...
...
@@ -318,7 +389,11 @@ __global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_key
block_keys
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
key
;
block_set_ids
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
set_id
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
uint32_t
n_warp_missing
=
0
;
Key
warp_missing_key
=
0
;
uint32_t
warp_missing_index
=
0
;
...
...
@@ -333,7 +408,11 @@ __global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_key
warp_missing_key
=
key
;
warp_missing_index
=
key_idx
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
n_warp_missing
+=
1
;
}
else
if
(
!
test_only
)
{
set_ctx
.
Read
(
cache_ctx
,
thread_ctx
,
way
,
values
+
key_idx
*
cache_ctx
.
line_size
);
...
...
@@ -342,15 +421,31 @@ __global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_key
if
(
n_warp_missing
>
0
)
{
uint32_t
base_missing_idx
=
0
;
if
(
thread_ctx
.
lane_id
==
0
)
{
base_missing_idx
=
atomicAdd
(
n_missing_keys
,
n_warp_missing
);
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
#ifdef WITH_ROCM
base_missing_idx
=
__shfl
(
base_missing_idx
,
0
);
#else
base_missing_idx
=
__shfl_sync
(
kFullMask
,
base_missing_idx
,
0
);
#endif
if
(
thread_ctx
.
lane_id
<
n_warp_missing
)
{
missing_keys
[
base_missing_idx
+
thread_ctx
.
lane_id
]
=
warp_missing_key
;
missing_indices
[
base_missing_idx
+
thread_ctx
.
lane_id
]
=
warp_missing_index
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
}
...
...
@@ -371,7 +466,11 @@ __global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, u
block_keys
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
key
;
block_set_ids
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
set_id
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
uint32_t
n_warp_missing
=
0
;
Key
warp_missing_key
=
0
;
uint32_t
warp_missing_index
=
0
;
...
...
@@ -390,7 +489,11 @@ __global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, u
warp_missing_key
=
key
;
warp_missing_index
=
key_idx
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
n_warp_missing
+=
1
;
}
set_ctx
.
Unlock
(
thread_ctx
);
...
...
@@ -398,13 +501,25 @@ __global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, u
if
(
n_warp_missing
>
0
)
{
uint32_t
base_missing_idx
=
0
;
if
(
thread_ctx
.
lane_id
==
0
)
{
base_missing_idx
=
atomicAdd
(
n_missing
,
n_warp_missing
);
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
#ifdef WITH_ROCM
base_missing_idx
=
__shfl
(
base_missing_idx
,
0
);
#else
base_missing_idx
=
__shfl_sync
(
kFullMask
,
base_missing_idx
,
0
);
#endif
if
(
thread_ctx
.
lane_id
<
n_warp_missing
)
{
missing_keys
[
base_missing_idx
+
thread_ctx
.
lane_id
]
=
warp_missing_key
;
missing_indices
[
base_missing_idx
+
thread_ctx
.
lane_id
]
=
warp_missing_index
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
}
}
}
...
...
@@ -427,7 +542,11 @@ __global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* key
block_keys
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
key
;
block_set_ids
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
set_id
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
for
(
uint32_t
i
=
0
;
i
<
n_batch_keys
;
++
i
)
{
const
uint32_t
key_idx
=
batch_offset
+
i
;
const
Key
key
=
block_keys
[
thread_ctx
.
warp_id_in_block
][
i
];
...
...
@@ -438,7 +557,11 @@ __global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* key
Key
evicted_key
=
0
;
set_ctx
.
Evict
(
cache_ctx
,
thread_ctx
,
key
,
&
evicted_way
,
&
evicted_key
);
if
(
thread_ctx
.
lane_id
==
0
)
{
evicted_keys
[
key_idx
]
=
evicted_key
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
set_ctx
.
Read
(
cache_ctx
,
thread_ctx
,
evicted_way
,
evicted_values
+
cache_ctx
.
line_size
*
key_idx
);
set_ctx
.
Write
(
cache_ctx
,
thread_ctx
,
evicted_way
,
...
...
@@ -463,26 +586,52 @@ __global__ void DumpKernel(LruCacheContext<Key, Elem> cache_ctx, size_t start_ke
lane_key
=
cache_ctx
.
keys
[
warp_start_key_index
+
thread_ctx
.
lane_id
];
lane_age
=
cache_ctx
.
ages
[
warp_start_key_index
+
thread_ctx
.
lane_id
];
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
warp_keys
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_key
;
warp_ages
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_age
;
#ifdef WITH_ROCM
const
int
key_count
=
__popc
(
static_cast
<
int
>
(
__ballot
(
lane_age
!=
0
)));
#else
const
int
key_count
=
__popc
(
__ballot_sync
(
kFullMask
,
lane_age
!=
0
));
#endif
if
(
key_count
==
0
)
{
continue
;
}
uint32_t
offset
=
0
;
if
(
thread_ctx
.
lane_id
==
0
)
{
offset
=
atomicAdd
(
n_dumped
,
key_count
);
}
#ifdef WITH_ROCM
offset
=
__shfl
(
offset
,
0
);
#else
offset
=
__shfl_sync
(
kFullMask
,
offset
,
0
);
__syncwarp
();
#endif
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
for
(
uint32_t
i
=
0
;
i
<
kWarpSize
;
++
i
)
{
const
Key
key
=
warp_keys
[
thread_ctx
.
warp_id_in_block
][
i
];
const
Key
age
=
warp_ages
[
thread_ctx
.
warp_id_in_block
][
i
];
if
(
age
==
0
)
{
continue
;
}
if
(
thread_ctx
.
lane_id
==
0
)
{
keys
[
offset
]
=
key
;
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
for
(
uint32_t
j
=
thread_ctx
.
lane_id
;
j
<
cache_ctx
.
line_size
;
j
+=
kWarpSize
)
{
values
[
offset
*
cache_ctx
.
line_size
+
j
]
=
cache_ctx
.
lines
[(
warp_start_key_index
+
i
)
*
cache_ctx
.
line_size
+
j
];
cache_ctx
.
lines
[
static_cast
<
size_t
>
(
warp_start_key_index
+
i
)
*
cache_ctx
.
line_size
+
j
];
}
__syncwarp
();
#ifdef WITH_ROCM
__syncthreads
();
#else
__syncwarp
();
#endif
offset
+=
1
;
}
}
...
...
@@ -498,14 +647,14 @@ class LruCache : public Cache {
query_indices_buffer_
(
nullptr
),
query_keys_buffer_
(
nullptr
),
value_type_
(
options
.
value_type
)
{
OF_CUDA_CHECK
(
cuda
GetDevice
(
&
device_index_
));
OF_CUDA_CHECK
(
GPU
(
GetDevice
)
(
&
device_index_
));
InitLruCacheContext
(
options
,
&
ctx_
);
}
~
LruCache
()
override
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
max_query_length_
!=
0
)
{
OF_CUDA_CHECK
(
cuda
Free
(
query_indices_buffer_
));
OF_CUDA_CHECK
(
cuda
Free
(
query_keys_buffer_
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
query_indices_buffer_
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
query_keys_buffer_
));
}
DestroyLruCacheContext
(
&
ctx_
);
}
...
...
@@ -520,11 +669,11 @@ class LruCache : public Cache {
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
query_length
<
max_query_length_
)
{
return
;
}
if
(
max_query_length_
!=
0
)
{
OF_CUDA_CHECK
(
cuda
Free
(
query_indices_buffer_
));
OF_CUDA_CHECK
(
cuda
Free
(
query_keys_buffer_
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
query_indices_buffer_
));
OF_CUDA_CHECK
(
GPU
(
Free
)
(
query_keys_buffer_
));
}
OF_CUDA_CHECK
(
cuda
Malloc
(
&
query_indices_buffer_
,
query_length
*
sizeof
(
uint32_t
)));
OF_CUDA_CHECK
(
cuda
Malloc
(
&
query_keys_buffer_
,
query_length
*
sizeof
(
Key
)));
OF_CUDA_CHECK
(
GPU
(
Malloc
)
(
&
query_indices_buffer_
,
query_length
*
sizeof
(
uint32_t
)));
OF_CUDA_CHECK
(
GPU
(
Malloc
)
(
&
query_keys_buffer_
,
query_length
*
sizeof
(
Key
)));
max_query_length_
=
query_length
;
}
...
...
@@ -534,18 +683,19 @@ class LruCache : public Cache {
void
*
missing_keys
,
uint32_t
*
missing_indices
)
override
{
CHECK_LE
(
n_keys
,
max_query_length_
);
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
OF_CUDA_CHECK
(
cuda
MemsetAsync
(
n_missing
,
0
,
sizeof
(
uint32_t
),
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemsetAsync
)
(
n_missing
,
0
,
sizeof
(
uint32_t
),
cuda_stream
->
cuda_stream
()));
if
(
n_keys
==
0
)
{
return
;
}
cuda_stream
->
LaunchKernel
(
GetKernel
<
Key
,
Elem
,
true
>
,
GetLaunchConfig
(
n_keys
),
ctx_
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
nullptr
,
n_missing
,
static_cast
<
Key
*>
(
missing_keys
),
missing_indices
);
}
using
Cache
::
Get
;
void
Get
(
ep
::
Stream
*
stream
,
uint32_t
n_keys
,
const
void
*
keys
,
void
*
values
,
uint32_t
*
n_missing
,
void
*
missing_keys
,
uint32_t
*
missing_indices
)
override
{
CHECK_LE
(
n_keys
,
max_query_length_
);
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
OF_CUDA_CHECK
(
cuda
MemsetAsync
(
n_missing
,
0
,
sizeof
(
uint32_t
),
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemsetAsync
)
(
n_missing
,
0
,
sizeof
(
uint32_t
),
cuda_stream
->
cuda_stream
()));
if
(
n_keys
==
0
)
{
return
;
}
cuda_stream
->
LaunchKernel
(
GetKernel
<
Key
,
Elem
,
false
>
,
GetLaunchConfig
(
n_keys
),
ctx_
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
static_cast
<
Elem
*>
(
values
),
n_missing
,
...
...
@@ -556,7 +706,7 @@ class LruCache : public Cache {
uint32_t
*
n_evicted
,
void
*
evicted_keys
,
void
*
evicted_values
)
override
{
CHECK_LE
(
n_keys
,
max_query_length_
);
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
OF_CUDA_CHECK
(
cuda
MemsetAsync
(
n_evicted
,
0
,
sizeof
(
uint32_t
),
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemsetAsync
)
(
n_evicted
,
0
,
sizeof
(
uint32_t
),
cuda_stream
->
cuda_stream
()));
if
(
n_keys
==
0
)
{
return
;
}
cuda_stream
->
LaunchKernel
(
PutWithoutEvictingKernel
<
Key
,
Elem
>
,
GetLaunchConfig
(
n_keys
),
ctx_
,
n_keys
,
static_cast
<
const
Key
*>
(
keys
),
...
...
@@ -571,7 +721,7 @@ class LruCache : public Cache {
void
Dump
(
ep
::
Stream
*
stream
,
uint64_t
start_key_index
,
uint64_t
end_key_index
,
uint32_t
*
n_dumped
,
void
*
keys
,
void
*
values
)
override
{
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
OF_CUDA_CHECK
(
cuda
MemsetAsync
(
n_dumped
,
0
,
sizeof
(
uint32_t
),
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemsetAsync
)
(
n_dumped
,
0
,
sizeof
(
uint32_t
),
cuda_stream
->
cuda_stream
()));
const
uint64_t
max_dump_keys
=
end_key_index
-
start_key_index
;
cuda_stream
->
LaunchKernel
(
DumpKernel
<
Key
,
Elem
>
,
...
...
@@ -581,6 +731,11 @@ class LruCache : public Cache {
static_cast
<
Elem
*>
(
values
));
}
void
ClearDirtyFlags
()
override
{
// do nothing.
return
;
}
void
Clear
()
override
{
ClearLruCacheContext
<
Key
,
Elem
>
(
&
ctx_
);
}
private:
...
...
oneflow/core/embedding/lru_cache.hip.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Inspired by https://github.com/NVIDIA-Merlin/HugeCTR/blob/master/gpu_cache/src/nv_gpu_cache.cu
#include "oneflow/core/embedding/lru_cache.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/hash_functions.hip.h"
#include <new>
#include <hip/hip_runtime.h>
namespace
oneflow
{
namespace
embedding
{
namespace
{
constexpr
int
kWarpSize
=
64
;
constexpr
int
kNumWarpPerBlock
=
2
;
constexpr
int
kBlockSize
=
kNumWarpPerBlock
*
kWarpSize
;
constexpr
unsigned
long
long
int
kFullMask
=
0xFFFFFFFFFFFFFFFFU
;
// Builds a launch configuration that assigns one warp (wavefront) per batch of
// kWarpSize keys: ceil(n_keys / kNumWarpPerBlock) blocks of
// kWarpSize * kNumWarpPerBlock threads, no dynamic shared memory.
ep::CudaLaunchConfig GetLaunchConfig(uint32_t n_keys) {
  const uint32_t num_blocks = (n_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock;
  const uint32_t threads_per_block = kWarpSize * kNumWarpPerBlock;
  return ep::CudaLaunchConfig(num_blocks, threads_per_block, 0);
}
// Per-thread identification within the warp-per-cache-set execution scheme.
// Derived once from the built-in launch indices at kernel entry.
struct ThreadContext {
  __device__ ThreadContext() {
    const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    lane_id = tid % kWarpSize;
    global_warp_id = tid / kWarpSize;
    warp_id_in_block = global_warp_id % kNumWarpPerBlock;  // NOLINT
    num_warps = gridDim.x * kNumWarpPerBlock;              // NOLINT
  }
  uint32_t global_warp_id;    // warp index across the whole grid
  uint32_t warp_id_in_block;  // warp index within this block
  uint32_t num_warps;         // total warps launched
  uint32_t lane_id;           // lane within the warp (0..kWarpSize-1)
};
// Spin-lock mutex guarding one cache set, implemented with device atomics.
// Lane 0 performs the actual atomic spin; the block-wide barriers make the
// acquisition/release visible to the other lanes of the warp.
// NOTE(review): Lock/Unlock use __syncthreads(), so every thread of the block
// must call them together (not just one warp) — confirm against the kernels
// that use this type.
class WarpMutexAtomicImpl {
 public:
  OF_DISALLOW_COPY_AND_MOVE(WarpMutexAtomicImpl);
  __device__ WarpMutexAtomicImpl() : flag_(0) {}
  __device__ ~WarpMutexAtomicImpl() = default;

  // Acquires the lock: lane 0 spins on an atomic compare-and-swap until it
  // flips flag_ 0 -> 1; the fence + barrier then order subsequent reads of the
  // protected data after the acquisition for all threads.
  __device__ void Lock(const ThreadContext& thread_ctx) {
    if (thread_ctx.lane_id == 0) {
      while (atomicCAS(&flag_, 0, 1) != 0)
        ;
    }
    __threadfence();
    __syncthreads();
  }

  // Releases the lock: barrier + fence order all protected writes before
  // lane 0 atomically clears the flag.
  __device__ void Unlock(const ThreadContext& thread_ctx) {
    __syncthreads();
    __threadfence();
    if (thread_ctx.lane_id == 0) { atomicExch(&flag_, 0); }
  }

 private:
  int32_t flag_;  // 0 = unlocked, 1 = locked
};
// Flat device-side description of the LRU cache: n_set cache sets, each with
// kWarpSize ways. All pointers are raw allocations owned by
// Init/DestroyLruCacheContext; this struct is passed to kernels by value.
template<typename Key, typename Elem>
struct LruCacheContext {
  Key* keys;        // [n_set * kWarpSize] key per way
  Elem* lines;      // [n_set * kWarpSize * line_size] cached value lines
  uint8_t* ages;    // [n_set * kWarpSize] LRU age per way; 0 = slot empty
  void* mutex;      // [n_set] WarpMutexAtomicImpl objects, type-erased
  uint64_t n_set;   // number of cache sets
  uint32_t line_size;  // elements per cached value line
  CacheOptions::MemoryKind value_memory_kind;  // where `lines` was allocated
};
// One thread per cache set: placement-constructs the set's mutex in the
// pre-allocated, type-erased `mutex` buffer.
__global__ void InitCacheSetMutex(uint32_t n_set, void* mutex) {
  using WarpMutex = WarpMutexAtomicImpl;
  const uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n_set) { return; }
  WarpMutex* slot = reinterpret_cast<WarpMutex*>(mutex) + i;
  new (slot) WarpMutex;
}
// Resets the cache to empty: zeroes all keys and ages (age 0 marks a free
// way) and re-constructs every per-set mutex.
template<typename Key, typename Elem>
void ClearLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
  const size_t n_ways = ctx->n_set * kWarpSize;
  OF_CUDA_CHECK(hipMemset(ctx->keys, 0, n_ways * sizeof(Key)));
  OF_CUDA_CHECK(hipMemset(ctx->ages, 0, n_ways * sizeof(uint8_t)));
  constexpr uint32_t kInitBlockSize = 256;
  const uint32_t num_blocks = (ctx->n_set + kInitBlockSize - 1) / kInitBlockSize;
  InitCacheSetMutex<<<num_blocks, kInitBlockSize>>>(ctx->n_set, ctx->mutex);
}
// Allocates and initializes all device/host buffers of the LRU cache.
// Capacity is rounded up to whole sets of kWarpSize ways. Value lines live in
// device memory or (optionally NUMA-aware) pinned host memory depending on
// options.value_memory_kind. Ends by clearing the cache to the empty state.
//
// Fix: removed the dead `major` compute-capability query — its result was
// never used (the HIP build always uses WarpMutexAtomicImpl), and the
// redundant zero-then-overwrite of mutex_size_per_set.
template<typename Key, typename Elem>
void InitLruCacheContext(const CacheOptions& options, LruCacheContext<Key, Elem>* ctx) {
  const size_t keys_size_per_set = kWarpSize * sizeof(Key);
  // Number of Elem entries per cached value; assumes value_size is a multiple
  // of sizeof(Elem).
  const uint32_t line_size = options.value_size / sizeof(Elem);
  const size_t lines_size_per_set = kWarpSize * line_size * sizeof(Elem);
  const size_t ages_size_per_set = kWarpSize * sizeof(uint8_t);
  int device = 0;
  OF_CUDA_CHECK(hipGetDevice(&device));
  // The HIP build always uses the atomic-based mutex implementation.
  const size_t mutex_size_per_set = sizeof(WarpMutexAtomicImpl);
  // Round capacity up to whole sets.
  const size_t n_set = (options.capacity - 1 + kWarpSize) / kWarpSize;
  CHECK_GT(n_set, 0);
  ctx->n_set = n_set;
  ctx->line_size = line_size;
  const size_t keys_size = n_set * keys_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->keys), keys_size));
  const size_t lines_size = n_set * lines_size_per_set;
  if (options.value_memory_kind == CacheOptions::MemoryKind::kDevice) {
    OF_CUDA_CHECK(hipMalloc(&(ctx->lines), lines_size));
  } else if (options.value_memory_kind == CacheOptions::MemoryKind::kHost) {
    if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_DISABLE_NUMA_AWARE_ALLOCATION", false)) {
      OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&(ctx->lines)), lines_size));
    } else {
      // Pin host memory on the NUMA node closest to `device`.
      OF_CUDA_CHECK(
          NumaAwareCudaMallocHost(device, reinterpret_cast<void**>(&ctx->lines), lines_size));
    }
  } else {
    UNIMPLEMENTED();
  }
  ctx->value_memory_kind = options.value_memory_kind;
  const size_t ages_size = n_set * ages_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->ages), ages_size));
  const size_t mutex_size = n_set * mutex_size_per_set;
  OF_CUDA_CHECK(hipMalloc(&(ctx->mutex), mutex_size));
  ClearLruCacheContext(ctx);
}
// Frees every buffer allocated by InitLruCacheContext. The value lines are
// released with the allocator matching value_memory_kind.
template<typename Key, typename Elem>
void DestroyLruCacheContext(LruCacheContext<Key, Elem>* ctx) {
  OF_CUDA_CHECK(hipFree(ctx->keys));
  switch (ctx->value_memory_kind) {
    case CacheOptions::MemoryKind::kDevice: OF_CUDA_CHECK(hipFree(ctx->lines)); break;
    case CacheOptions::MemoryKind::kHost: OF_CUDA_CHECK(hipHostFree(ctx->lines)); break;
    default: UNIMPLEMENTED();
  }
  OF_CUDA_CHECK(hipFree(ctx->ages));
  OF_CUDA_CHECK(hipFree(ctx->mutex));
}
// Warp-cooperative view of one cache set: kWarpSize ways, one lane per way.
// All member functions must be called by every lane of the warp together.
//
// Fix: on ROCm a wavefront is 64 lanes wide and __ballot returns a 64-bit
// mask, but the original code narrowed the mask with static_cast<int> before
// __ffs/__popc, silently discarding hits and occupancy information for lanes
// 32..63. Use the 64-bit intrinsics __ffsll/__popcll on the full mask.
template<typename Key, typename Elem>
struct SetContext {
  using WarpMutex = WarpMutexAtomicImpl;

  __device__ SetContext(const LruCacheContext<Key, Elem>& ctx, uint32_t set_id)
      : keys(ctx.keys + set_id * kWarpSize),
        mutex(reinterpret_cast<WarpMutex*>(ctx.mutex) + set_id),
        ages(ctx.ages + set_id * kWarpSize),
        lines(ctx.lines + set_id * kWarpSize * ctx.line_size) {}

  // Returns the way holding `key` (age != 0 means the way is occupied), or -1
  // on a miss. Every lane checks its own way; the ballot gathers the result.
  __device__ int Lookup(const ThreadContext& thread_ctx, Key key) {
    const Key lane_key = keys[thread_ctx.lane_id];
    const int lane_age = ages[thread_ctx.lane_id];
    const bool lane_hit = (lane_key == key && lane_age != 0);
    const unsigned long long int hit_mask = __ballot(lane_hit);
    if (hit_mask != 0) {
      return __ffsll(hit_mask) - 1;  // 64-bit ffs: lanes 32..63 count too
    } else {
      return -1;
    }
  }

  // Copies the cached line of `way` into `line`, striding lanes over elements.
  __device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx,
                       const ThreadContext& thread_ctx, int way, Elem* line) {
    const Elem* from_line = lines + way * cache_ctx.line_size;
    for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
      line[i] = from_line[i];
    }
  }

  // Inserts `key` if it already exists (refresh) or a free way is available.
  // Returns the way used, or -1 when the set is full. Updates LRU ages.
  __device__ int InsertWithoutEvicting(const LruCacheContext<Key, Elem>& cache_ctx,
                                       const ThreadContext& thread_ctx, Key key) {
    int insert_way = -1;
    const Key lane_key = keys[thread_ctx.lane_id];
    int lane_age = ages[thread_ctx.lane_id];
    const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
    if (hit_mask != 0) {
      // Key already cached: promote it to most-recently-used, age the rest.
      insert_way = __ffsll(hit_mask) - 1;
      const int insert_way_age = __shfl(lane_age, insert_way);
      if (lane_age > insert_way_age) {
        lane_age -= 1;
      } else if (thread_ctx.lane_id == insert_way) {
        lane_age = kWarpSize;
      }
      __syncthreads();
    }
    if (insert_way == -1) {
      const unsigned long long int valid_mask = __ballot(lane_age != 0);
      if (valid_mask != kFullMask) {
        // Free way available: claim the first empty slot (64-bit popcount).
        insert_way = __popcll(valid_mask);
        if (lane_age > 0) {
          lane_age -= 1;
        } else if (thread_ctx.lane_id == insert_way) {
          lane_age = kWarpSize;
          keys[insert_way] = key;
        }
        __syncthreads();
      }
    }
    if (insert_way != -1) { ages[thread_ctx.lane_id] = lane_age; }
    return insert_way;
  }

  // Evicts the least-recently-used way (age == 1), installing `key` there.
  // Outputs the chosen way and the evicted key.
  // NOTE(review): assumes some lane has age == 1 (a full set); otherwise the
  // ballot is 0 and insert_way would be -1 — confirm callers only evict when
  // InsertWithoutEvicting returned -1.
  __device__ void Evict(const LruCacheContext<Key, Elem>& cache_ctx,
                        const ThreadContext& thread_ctx, Key key, int* way, Key* evicted_key) {
    const Key lane_key = keys[thread_ctx.lane_id];
    int lane_age = ages[thread_ctx.lane_id];
    const int insert_way = __ffsll(__ballot(lane_age == 1)) - 1;  // 64-bit ffs
    *evicted_key = __shfl(lane_key, insert_way);
    if (thread_ctx.lane_id == insert_way) {
      keys[insert_way] = key;
      lane_age = kWarpSize;
    } else if (lane_age > 1) {
      lane_age -= 1;
    }
    __syncthreads();
    ages[thread_ctx.lane_id] = lane_age;
    *way = insert_way;
  }

  // Copies `line` into the cached line of `way`, striding lanes over elements.
  __device__ void Write(const LruCacheContext<Key, Elem>& cache_ctx,
                        const ThreadContext& thread_ctx, int way, const Elem* line) {
    Elem* to_line = lines + way * cache_ctx.line_size;
    for (int i = thread_ctx.lane_id; i < cache_ctx.line_size; i += kWarpSize) {
      to_line[i] = line[i];
    }
  }

  __device__ void Lock(const ThreadContext& thread_ctx) { mutex->Lock(thread_ctx); }
  __device__ void Unlock(const ThreadContext& thread_ctx) { mutex->Unlock(thread_ctx); }

  Key* keys;        // this set's key array (kWarpSize entries)
  Elem* lines;      // this set's value lines
  uint8_t* ages;    // this set's LRU ages; 0 = empty way
  WarpMutex* mutex; // this set's lock
};
// Looks up `num_keys` keys in the LRU cache. For each hit (unless test_only) the cached
// line is copied into `values`; for each miss the key and its original index are appended
// to `missing_keys`/`missing_indices`, with `n_missing_keys` advanced via atomicAdd.
// Layout: each warp claims batches of kWarpSize keys in a grid-stride fashion; within a
// batch, all lanes of the warp cooperate on one key at a time.
template<typename Key, typename Elem, bool test_only>
__global__ void GetKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys, const Key* keys,
                          Elem* values, uint32_t* n_missing_keys, Key* missing_keys,
                          uint32_t* missing_indices) {
  ThreadContext thread_ctx{};
  // Per-warp staging of the batch's keys and their set ids in shared memory.
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    // Tail batch may be shorter than a full warp.
    const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    // NOTE(review): block-wide barrier; warps in a block may run different trip counts of
    // the outer loop, so this relies on uniform batch counts per block — verify.
    __syncthreads();
    // Per-warp miss bookkeeping: lane k remembers the k-th missing key of this batch.
    uint32_t n_warp_missing = 0;
    Key warp_missing_key = 0;
    uint32_t warp_missing_index = 0;
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      // way < 0 means the key is not cached.
      const int way = set_ctx.Lookup(thread_ctx, key);
      if (way < 0) {
        if (thread_ctx.lane_id == n_warp_missing) {
          warp_missing_key = key;
          warp_missing_index = key_idx;
        }
        __syncthreads();
        n_warp_missing += 1;
      } else if (!test_only) {
        // Hit: copy the cached line to the caller's output slot.
        set_ctx.Read(cache_ctx, thread_ctx, way, values + key_idx * cache_ctx.line_size);
      }
    }
    if (n_warp_missing > 0) {
      // Lane 0 reserves a contiguous range in the global miss list, then broadcasts it.
      uint32_t base_missing_idx = 0;
      if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing_keys, n_warp_missing); }
      __syncthreads();
      base_missing_idx = __shfl(base_missing_idx, 0);
      // The first n_warp_missing lanes write their remembered miss.
      if (thread_ctx.lane_id < n_warp_missing) {
        missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
        missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
      }
      __syncthreads();
    }
    __syncthreads();
  }
}
// Inserts `num_keys` key/value pairs into the cache without evicting resident entries.
// Keys whose set is full (InsertWithoutEvicting returns < 0) are appended to
// `missing_keys`/`missing_indices` (count in `n_missing`) so EvictKernel can handle them
// in a second pass. Each warp processes batches of kWarpSize keys grid-stride.
// Fix vs. original: removed the dead local `Key evicted_key = 0;` — it was never read or
// passed anywhere in this kernel.
template<typename Key, typename Elem>
__global__ void PutWithoutEvictingKernel(LruCacheContext<Key, Elem> cache_ctx, uint32_t num_keys,
                                         const Key* keys, const Elem* values, uint32_t* n_missing,
                                         Key* missing_keys, uint32_t* missing_indices) {
  ThreadContext thread_ctx{};
  // Per-warp staging of keys and set ids in shared memory.
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_keys;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    const uint32_t n_batch_keys = min(kWarpSize, num_keys - batch_offset);
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    __syncthreads();
    // Per-warp overflow bookkeeping: lane k remembers the k-th key that did not fit.
    uint32_t n_warp_missing = 0;
    Key warp_missing_key = 0;
    uint32_t warp_missing_index = 0;
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const size_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      // Serialize concurrent warps that hash to the same set.
      set_ctx.Lock(thread_ctx);
      const int insert_way = set_ctx.InsertWithoutEvicting(cache_ctx, thread_ctx, key);
      if (insert_way >= 0) {
        // Inserted (or already present): write the value into the chosen way.
        set_ctx.Write(cache_ctx, thread_ctx, insert_way, values + cache_ctx.line_size * key_idx);
      } else {
        // Set full: defer this key to the eviction pass.
        if (thread_ctx.lane_id == n_warp_missing) {
          warp_missing_key = key;
          warp_missing_index = key_idx;
        }
        __syncthreads();
        n_warp_missing += 1;
      }
      set_ctx.Unlock(thread_ctx);
    }
    if (n_warp_missing > 0) {
      // Lane 0 reserves space in the global overflow list and broadcasts the base index.
      uint32_t base_missing_idx = 0;
      if (thread_ctx.lane_id == 0) { base_missing_idx = atomicAdd(n_missing, n_warp_missing); }
      __syncthreads();
      base_missing_idx = __shfl(base_missing_idx, 0);
      if (thread_ctx.lane_id < n_warp_missing) {
        missing_keys[base_missing_idx + thread_ctx.lane_id] = warp_missing_key;
        missing_indices[base_missing_idx + thread_ctx.lane_id] = warp_missing_index;
      }
      __syncthreads();
    }
  }
}
// Second pass of Put: for each of *n_evict keys that did not fit without eviction,
// evicts the LRU way of its set, records the evicted key in `evicted_keys[key_idx]` and
// the evicted line in `evicted_values`, then writes the new value (taken from
// `values + line_size * indices[key_idx]`, i.e. the key's original position in the Put
// request). `n_evict` is read on-device because it was produced by the previous kernel.
template<typename Key, typename Elem>
__global__ void EvictKernel(LruCacheContext<Key, Elem> cache_ctx, const Key* keys,
                            const uint32_t* indices, const Elem* values, const uint32_t* n_evict,
                            Key* evicted_keys, Elem* evicted_values) {
  ThreadContext thread_ctx{};
  // Count produced by PutWithoutEvictingKernel on the same stream.
  uint32_t num_evict = *n_evict;
  __shared__ Key block_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ size_t block_set_ids[kNumWarpPerBlock][kWarpSize];
  for (uint32_t batch_offset = thread_ctx.global_warp_id * kWarpSize; batch_offset < num_evict;
       batch_offset += thread_ctx.num_warps * kWarpSize) {
    const uint32_t n_batch_keys = min(kWarpSize, num_evict - batch_offset);
    if (thread_ctx.lane_id < n_batch_keys) {
      const Key key = keys[batch_offset + thread_ctx.lane_id];
      const size_t hash = LruCacheHash()(key);
      const uint32_t set_id = hash % cache_ctx.n_set;
      block_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = key;
      block_set_ids[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = set_id;
    }
    __syncthreads();
    for (uint32_t i = 0; i < n_batch_keys; ++i) {
      const uint32_t key_idx = batch_offset + i;
      const Key key = block_keys[thread_ctx.warp_id_in_block][i];
      const uint32_t set_id = block_set_ids[thread_ctx.warp_id_in_block][i];
      SetContext<Key, Elem> set_ctx(cache_ctx, set_id);
      set_ctx.Lock(thread_ctx);
      int evicted_way = -1;
      Key evicted_key = 0;
      // Replace the LRU way of the set with `key`; get back the victim way/key.
      set_ctx.Evict(cache_ctx, thread_ctx, key, &evicted_way, &evicted_key);
      if (thread_ctx.lane_id == 0) { evicted_keys[key_idx] = evicted_key; }
      __syncthreads();
      // Save the victim's line before overwriting it with the new value.
      set_ctx.Read(cache_ctx, thread_ctx, evicted_way, evicted_values + cache_ctx.line_size * key_idx);
      set_ctx.Write(cache_ctx, thread_ctx, evicted_way,
                    values + cache_ctx.line_size * indices[key_idx]);
      set_ctx.Unlock(thread_ctx);
    }
  }
}
// Dumps every valid (age != 0) entry whose slot index lies in
// [start_key_index, end_key_index) into the compacted `keys`/`values` arrays;
// `n_dumped` is advanced atomically so concurrent warps write disjoint ranges.
// Each warp scans kWarpSize consecutive slots per iteration, grid-stride.
// Fixes vs. original: the per-entry `age` local was declared as `Key` although it
// holds a uint8_t age; the loop variable `warp_start_key_index` was uint32_t although
// it is initialized from and compared against size_t indices (silent truncation for
// caches with more than 2^32 slots). Both are type corrections with identical results
// for in-range inputs.
template<typename Key, typename Elem>
__global__ void DumpKernel(LruCacheContext<Key, Elem> cache_ctx, size_t start_key_index,
                           size_t end_key_index, uint32_t* n_dumped, Key* keys, Elem* values) {
  ThreadContext thread_ctx{};
  __shared__ Key warp_keys[kNumWarpPerBlock][kWarpSize];
  __shared__ uint8_t warp_ages[kNumWarpPerBlock][kWarpSize];
  for (size_t warp_start_key_index = start_key_index + thread_ctx.global_warp_id * kWarpSize;
       warp_start_key_index < end_key_index;
       warp_start_key_index += thread_ctx.num_warps * kWarpSize) {
    // Lanes past the end of the range contribute age == 0 (i.e. "empty").
    Key lane_key = 0;
    uint8_t lane_age = 0;
    if (warp_start_key_index + thread_ctx.lane_id < end_key_index) {
      lane_key = cache_ctx.keys[warp_start_key_index + thread_ctx.lane_id];
      lane_age = cache_ctx.ages[warp_start_key_index + thread_ctx.lane_id];
    }
    __syncthreads();
    warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
    warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
    // Number of valid entries in this window of kWarpSize slots.
    // NOTE(review): __ballot yields a 64-bit mask on AMD wavefronts; the
    // static_cast<int> narrows it before __popc — confirm kWarpSize <= 32 here.
    const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
    if (key_count == 0) { continue; }
    // Lane 0 reserves key_count output slots, then broadcasts the base offset.
    uint32_t offset = 0;
    if (thread_ctx.lane_id == 0) { offset = atomicAdd(n_dumped, key_count); }
    offset = __shfl(offset, 0);
    __syncthreads();
    for (uint32_t i = 0; i < kWarpSize; ++i) {
      const Key key = warp_keys[thread_ctx.warp_id_in_block][i];
      const uint8_t age = warp_ages[thread_ctx.warp_id_in_block][i];
      if (age == 0) { continue; }
      if (thread_ctx.lane_id == 0) { keys[offset] = key; }
      __syncthreads();
      // All lanes cooperatively copy this entry's line into the compacted output.
      for (uint32_t j = thread_ctx.lane_id; j < cache_ctx.line_size; j += kWarpSize) {
        values[offset * cache_ctx.line_size + j] =
            cache_ctx.lines[(warp_start_key_index + i) * cache_ctx.line_size + j];
      }
      __syncthreads();
      offset += 1;
    }
  }
}
// Device-resident set-associative LRU cache (HIP/ROCm backend). Each set holds kWarpSize
// ways; Put is two kernel passes (insert-without-evicting, then evict for overflow keys),
// using query_keys_buffer_/query_indices_buffer_ as the inter-pass scratch space.
// Fix vs. original: ReserveQueryLength now early-returns on `<=` instead of `<`, matching
// the other ReserveQueryLength implementations in this module and avoiding a pointless
// free + realloc of same-sized buffers when the current capacity already suffices.
template<typename Key, typename Elem>
class LruCache : public Cache {
 public:
  OF_DISALLOW_COPY_AND_MOVE(LruCache);
  explicit LruCache(const CacheOptions& options)
      : device_index_{},
        max_query_length_(0),
        query_indices_buffer_(nullptr),
        query_keys_buffer_(nullptr),
        value_type_(options.value_type) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    InitLruCacheContext(options, &ctx_);
  }
  ~LruCache() override {
    CudaCurrentDeviceGuard guard(device_index_);
    // Scratch buffers exist only after a successful ReserveQueryLength.
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipFree(query_indices_buffer_));
      OF_CUDA_CHECK(hipFree(query_keys_buffer_));
    }
    DestroyLruCacheContext(&ctx_);
  }
  uint32_t KeySize() const override { return sizeof(Key); }
  // Value size in bytes: one line of Elem per key.
  uint32_t ValueSize() const override { return sizeof(Elem) * ctx_.line_size; }
  DataType ValueType() const override { return value_type_; }
  // Total entries: n_set sets of kWarpSize ways.
  uint64_t Capacity() const override { return ctx_.n_set * kWarpSize; }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  // Grows the per-query scratch buffers; no-op when the current capacity is enough.
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipFree(query_indices_buffer_));
      OF_CUDA_CHECK(hipFree(query_keys_buffer_));
    }
    OF_CUDA_CHECK(hipMalloc(&query_indices_buffer_, query_length * sizeof(uint32_t)));
    OF_CUDA_CHECK(hipMalloc(&query_keys_buffer_, query_length * sizeof(Key)));
    max_query_length_ = query_length;
  }
  CacheOptions::Policy Policy() const override { return CacheOptions::Policy::kLRU; }
  // Membership test only: fills n_missing/missing_keys/missing_indices, copies no values.
  void Test(ep::Stream* stream, uint32_t n_keys, const void* keys, uint32_t* n_missing,
            void* missing_keys, uint32_t* missing_indices) override {
    CHECK_LE(n_keys, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
    if (n_keys == 0) { return; }
    cuda_stream->LaunchKernel(GetKernel<Key, Elem, true>, GetLaunchConfig(n_keys), ctx_, n_keys,
                              static_cast<const Key*>(keys), nullptr, n_missing,
                              static_cast<Key*>(missing_keys), missing_indices);
  }
  // Lookup: copies hit lines into `values`, reports misses via n_missing/missing_*.
  void Get(ep::Stream* stream, uint32_t n_keys, const void* keys, void* values,
           uint32_t* n_missing, void* missing_keys, uint32_t* missing_indices) override {
    CHECK_LE(n_keys, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
    if (n_keys == 0) { return; }
    cuda_stream->LaunchKernel(GetKernel<Key, Elem, false>, GetLaunchConfig(n_keys), ctx_, n_keys,
                              static_cast<const Key*>(keys), static_cast<Elem*>(values), n_missing,
                              static_cast<Key*>(missing_keys), missing_indices);
  }
  // Insert: pass 1 inserts keys whose set has room; overflowing keys are staged in the
  // scratch buffers and handled by pass 2, which evicts LRU entries and reports them
  // through n_evicted/evicted_keys/evicted_values.
  void Put(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
           uint32_t* n_evicted, void* evicted_keys, void* evicted_values) override {
    CHECK_LE(n_keys, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(n_evicted, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
    if (n_keys == 0) { return; }
    cuda_stream->LaunchKernel(PutWithoutEvictingKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
                              n_keys, static_cast<const Key*>(keys),
                              static_cast<const Elem*>(values), n_evicted, query_keys_buffer_,
                              query_indices_buffer_);
    cuda_stream->LaunchKernel(EvictKernel<Key, Elem>, GetLaunchConfig(n_keys), ctx_,
                              query_keys_buffer_, query_indices_buffer_,
                              static_cast<const Elem*>(values), n_evicted,
                              static_cast<Key*>(evicted_keys), static_cast<Elem*>(evicted_values));
  }
  // Compacts all valid entries in slot range [start_key_index, end_key_index) into
  // keys/values; *n_dumped receives the count.
  void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
            uint32_t* n_dumped, void* keys, void* values) override {
    auto cuda_stream = stream->As<ep::CudaStream>();
    OF_CUDA_CHECK(hipMemsetAsync(n_dumped, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
    const uint64_t max_dump_keys = end_key_index - start_key_index;
    cuda_stream->LaunchKernel(
        DumpKernel<Key, Elem>,
        ep::CudaLaunchConfig((max_dump_keys + kNumWarpPerBlock - 1) / kNumWarpPerBlock, kBlockSize,
                             0),
        ctx_, start_key_index, end_key_index, n_dumped, static_cast<Key*>(keys),
        static_cast<Elem*>(values));
  }
  void Clear() override { ClearLruCacheContext<Key, Elem>(&ctx_); }

 private:
  int device_index_;
  // 0 until ReserveQueryLength succeeds; also gates scratch-buffer ownership.
  uint32_t max_query_length_;
  LruCacheContext<Key, Elem> ctx_;
  // Device scratch shared by Put's two passes (indices/keys of overflowing entries).
  uint32_t* query_indices_buffer_;
  Key* query_keys_buffer_;
  DataType value_type_;
};
// Selects the widest element type that evenly divides options.value_size, so each cache
// line is moved with the largest possible loads/stores, and instantiates LruCache with it.
// Falls back to byte granularity when nothing wider divides the value size.
template<typename Key>
std::unique_ptr<Cache> DispatchValueType(const CacheOptions& options) {
  const auto elem_bytes = options.value_size;
  if (elem_bytes % sizeof(ulonglong2) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, ulonglong2>(options));
  }
  if (elem_bytes % sizeof(uint64_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint64_t>(options));
  }
  if (elem_bytes % sizeof(uint32_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint32_t>(options));
  }
  if (elem_bytes % sizeof(uint16_t) == 0) {
    return std::unique_ptr<Cache>(new LruCache<Key, uint16_t>(options));
  }
  return std::unique_ptr<Cache>(new LruCache<Key, uint8_t>(options));
}
// Instantiates the cache for the configured key width (4 or 8 bytes); any other
// width is unsupported and aborts via UNIMPLEMENTED().
std::unique_ptr<Cache> DispatchKeyType(const CacheOptions& options) {
  switch (options.key_size) {
    case sizeof(uint32_t): return DispatchValueType<uint32_t>(options);
    case sizeof(uint64_t): return DispatchValueType<uint64_t>(options);
    default: {
      UNIMPLEMENTED();
      return nullptr;
    }
  }
}
}
// namespace
// Public factory for the LRU cache: dispatches on key width, then on value alignment.
std::unique_ptr<Cache> NewLruCache(const CacheOptions& options) {
  std::unique_ptr<Cache> cache = DispatchKeyType(options);
  return cache;
}
}
// namespace embedding
}
// namespace oneflow
\ No newline at end of file
oneflow/core/embedding/mock_key_value_store.cu
View file @
a715222c
...
...
@@ -50,14 +50,14 @@ class IteratorImpl : public KVIterator {
std
::
memcpy
(
reinterpret_cast
<
char
*>
(
host_values_buffer_
)
+
*
host_num_buffer_
*
value_size_
,
pos_
->
second
.
data
(),
value_size_
);
}
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
n_result
,
host_num_buffer_
,
sizeof
(
uint32_t
),
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
n_result
,
host_num_buffer_
,
sizeof
(
uint32_t
),
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
const
uint32_t
num_keys
=
*
host_num_buffer_
;
if
(
num_keys
!=
0
)
{
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
keys
,
host_keys_buffer_
,
num_keys
*
key_size_
,
cuda
MemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
values
,
host_values_buffer_
,
num_keys
*
value_size_
,
cuda
MemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
keys
,
host_keys_buffer_
,
num_keys
*
key_size_
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
values
,
host_values_buffer_
,
num_keys
*
value_size_
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
}
}
...
...
@@ -80,7 +80,7 @@ class KeyValueStoreImpl : public KeyValueStore {
OF_DISALLOW_COPY_AND_MOVE
(
KeyValueStoreImpl
);
explicit
KeyValueStoreImpl
(
const
MockKeyValueStoreOptions
&
options
)
:
device_index_
(
-
1
),
max_query_length_
(
0
)
{
OF_CUDA_CHECK
(
cuda
GetDevice
(
&
device_index_
));
OF_CUDA_CHECK
(
GPU
(
GetDevice
)
(
&
device_index_
));
key_size_
=
options
.
key_size
;
value_size_
=
options
.
value_size
;
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
...
...
@@ -97,11 +97,11 @@ class KeyValueStoreImpl : public KeyValueStore {
~
KeyValueStoreImpl
()
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
max_query_length_
!=
0
)
{
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_query_keys_
));
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_query_values_
));
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_missing_indices_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_query_keys_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_query_values_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_missing_indices_
));
}
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_n_missing_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_n_missing_
));
}
uint32_t
KeySize
()
const
override
{
return
key_size_
;
}
...
...
@@ -114,9 +114,9 @@ class KeyValueStoreImpl : public KeyValueStore {
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
query_length
<=
max_query_length_
)
{
return
;
}
if
(
max_query_length_
!=
0
)
{
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_query_keys_
));
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_query_values_
));
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_missing_indices_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_query_keys_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_query_values_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_missing_indices_
));
}
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_query_keys_
),
key_size_
*
query_length
));
...
...
@@ -128,6 +128,7 @@ class KeyValueStoreImpl : public KeyValueStore {
max_query_length_
=
query_length
;
}
using
KeyValueStore
::
Get
;
void
Get
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
void
*
keys
,
void
*
values
,
uint32_t
*
n_missing
,
uint32_t
*
missing_indices
)
override
;
void
Put
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
void
*
keys
,
const
void
*
values
)
override
;
...
...
@@ -158,11 +159,11 @@ void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const vo
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
CHECK_LE
(
num_keys
,
max_query_length_
);
if
(
num_keys
==
0
)
{
OF_CUDA_CHECK
(
cuda
MemsetAsync
(
n_missing
,
0
,
sizeof
(
uint32_t
),
OF_CUDA_CHECK
(
GPU
(
MemsetAsync
)
(
n_missing
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
return
;
}
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
CHECK_JUST
(
cuda_stream
->
Sync
());
*
host_n_missing_
=
0
;
...
...
@@ -175,12 +176,12 @@ void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const vo
*
host_n_missing_
+=
1
;
}
}
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
values
,
host_query_values_
,
num_keys
*
value_size_
,
cuda
MemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
n_missing
,
host_n_missing_
,
sizeof
(
uint32_t
),
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
values
,
host_query_values_
,
num_keys
*
value_size_
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
n_missing
,
host_n_missing_
,
sizeof
(
uint32_t
),
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
missing_indices
,
host_missing_indices_
,
(
*
host_n_missing_
)
*
sizeof
(
uint32_t
),
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
missing_indices
,
host_missing_indices_
,
(
*
host_n_missing_
)
*
sizeof
(
uint32_t
),
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
}
...
...
@@ -191,10 +192,10 @@ void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const vo
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
CHECK_LE
(
num_keys
,
max_query_length_
);
if
(
num_keys
==
0
)
{
return
;
}
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
host_query_values_
,
values
,
value_size_
*
num_keys
,
cuda
MemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
host_query_values_
,
values
,
value_size_
*
num_keys
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
CHECK_JUST
(
cuda_stream
->
Sync
());
for
(
uint32_t
i
=
0
;
i
<
num_keys
;
++
i
)
{
store_
[
host_query_keys_
[
i
]]
=
std
::
string
(
...
...
oneflow/core/embedding/mock_key_value_store.h
View file @
a715222c
...
...
@@ -22,17 +22,7 @@ namespace oneflow {
namespace
embedding
{
#ifdef WITH_CUDA
struct
MockKeyValueStoreOptions
{
uint32_t
key_size
=
0
;
uint32_t
value_size
=
0
;
};
std
::
unique_ptr
<
KeyValueStore
>
NewMockKeyValueStore
(
const
MockKeyValueStoreOptions
&
options
);
#endif // WITH_CUDA
#ifdef WITH_ROCM
#if defined(WITH_CUDA) || defined(WITH_ROCM)
struct
MockKeyValueStoreOptions
{
uint32_t
key_size
=
0
;
...
...
oneflow/core/embedding/mock_key_value_store.hip.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/mock_key_value_store.h"
#include "oneflow/core/device/cuda_util.h"
namespace
oneflow
{
namespace
embedding
{
namespace
{
// Host-side iterator over the in-memory mock store. NextN stages up to n_request
// entries into pinned host buffers and then copies them to the caller's (device)
// keys/values buffers asynchronously on the given stream.
// Fix vs. original: the staging loop in NextN never advanced `pos_` nor
// `*host_num_buffer_`, so any non-empty store made it spin forever rewriting slot 0.
// Both increments are restored so the loop terminates and fills consecutive slots.
template<typename Key>
class IteratorImpl : public KVIterator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(IteratorImpl);
  IteratorImpl(HashMap<Key, std::string>* store, uint32_t key_size, uint32_t value_size,
               uint32_t max_query_length, void* host_keys_buffer, void* host_values_buffer,
               uint32_t* host_num_buffer)
      : store_(store),
        pos_(store->begin()),
        key_size_(key_size),
        value_size_(value_size),
        max_query_length_(max_query_length),
        host_keys_buffer_(host_keys_buffer),
        host_values_buffer_(host_values_buffer),
        host_num_buffer_(host_num_buffer) {}
  ~IteratorImpl() override = default;
  // Copies the next up-to-n_request entries to keys/values; *n_result receives the count.
  void NextN(ep::Stream* stream, uint32_t n_request, uint32_t* n_result, void* keys,
             void* values) override {
    CHECK_LE(n_request, max_query_length_);
    auto cuda_stream = stream->As<ep::CudaStream>();
    // Drain prior work so the pinned buffers can be reused safely.
    CHECK_JUST(cuda_stream->Sync());
    *host_num_buffer_ = 0;
    while (*host_num_buffer_ < n_request && pos_ != store_->end()) {
      reinterpret_cast<Key*>(host_keys_buffer_)[*host_num_buffer_] = pos_->first;
      std::memcpy(reinterpret_cast<char*>(host_values_buffer_) + *host_num_buffer_ * value_size_,
                  pos_->second.data(), value_size_);
      *host_num_buffer_ += 1;  // advance output slot (fix: was missing)
      ++pos_;                  // advance store cursor (fix: was missing)
    }
    OF_CUDA_CHECK(hipMemcpyAsync(n_result, host_num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                                 cuda_stream->cuda_stream()));
    const uint32_t num_keys = *host_num_buffer_;
    if (num_keys != 0) {
      OF_CUDA_CHECK(hipMemcpyAsync(keys, host_keys_buffer_, num_keys * key_size_, hipMemcpyDefault,
                                   cuda_stream->cuda_stream()));
      OF_CUDA_CHECK(hipMemcpyAsync(values, host_values_buffer_, num_keys * value_size_,
                                   hipMemcpyDefault, cuda_stream->cuda_stream()));
    }
  }
  // Rewinds the cursor to the first entry of the store.
  void Reset() override { pos_ = store_->begin(); }

 private:
  HashMap<Key, std::string>* store_;
  typename HashMap<Key, std::string>::iterator pos_;
  uint32_t key_size_;
  uint32_t value_size_;
  uint32_t max_query_length_;
  // Pinned host staging buffers, owned by the enclosing KeyValueStoreImpl.
  void* host_keys_buffer_;
  void* host_values_buffer_;
  uint32_t* host_num_buffer_;
};
</style_update>
// In-memory mock key/value store backed by a HashMap<Key, std::string>, with pinned
// host staging buffers for async device<->host transfers. Get/Put are serialized by
// mutex_. Snapshots are whole-map copies kept in snapshots_.
template<typename Key>
class KeyValueStoreImpl : public KeyValueStore {
 public:
  OF_DISALLOW_COPY_AND_MOVE(KeyValueStoreImpl);
  explicit KeyValueStoreImpl(const MockKeyValueStoreOptions& options)
      : device_index_(-1), max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    key_size_ = options.key_size;
    value_size_ = options.value_size;
    // NOTE(review): max_query_length_ is 0 here, so these are zero-byte host
    // allocations; ReserveQueryLength later reallocates but (since
    // max_query_length_ == 0) does not free these initial pointers — presumably a
    // benign leak of empty allocations; confirm.
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_query_keys_),
                                          key_size_ * max_query_length_));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_query_values_),
                                          value_size_ * max_query_length_));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_n_missing_),
                                          sizeof(uint32_t)));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_missing_indices_),
                                          sizeof(uint32_t) * max_query_length_));
  }
  ~KeyValueStoreImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    // Query buffers were (re)allocated by ReserveQueryLength iff max_query_length_ != 0.
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipHostFree(host_query_keys_));
      OF_CUDA_CHECK(hipHostFree(host_query_values_));
      OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
    }
    OF_CUDA_CHECK(hipHostFree(host_n_missing_));
  }
  uint32_t KeySize() const override { return key_size_; }
  uint32_t ValueSize() const override { return value_size_; }
  uint32_t MaxQueryLength() const override { return max_query_length_; }
  // Grows the pinned staging buffers; no-op if the current capacity already suffices.
  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipHostFree(host_query_keys_));
      OF_CUDA_CHECK(hipHostFree(host_query_values_));
      OF_CUDA_CHECK(hipHostFree(host_missing_indices_));
    }
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_query_keys_),
                                          key_size_ * query_length));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_query_values_),
                                          value_size_ * query_length));
    OF_CUDA_CHECK(NumaAwareCudaMallocHost(device_index_,
                                          reinterpret_cast<void**>(&host_missing_indices_),
                                          sizeof(uint32_t) * query_length));
    max_query_length_ = query_length;
  }
  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
           uint32_t* n_missing, uint32_t* missing_indices) override;
  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
  bool SnapshotExists(const std::string& name) override;
  void LoadSnapshot(const std::string& name) override;
  void LoadSnapshot(const std::string& name,
                    const std::function<void(KVIterator* iter)>& Hook) override;
  void SaveSnapshot(const std::string& name) override;

 private:
  int device_index_;
  uint32_t max_query_length_;
  uint32_t key_size_;
  uint32_t value_size_;
  // Pinned host staging buffers (sized by max_query_length_).
  Key* host_query_keys_{};
  uint8_t* host_query_values_{};
  uint32_t* host_n_missing_{};
  uint32_t* host_missing_indices_{};
  // Live table and named whole-table snapshots.
  HashMap<Key, std::string> store_;
  HashMap<std::string, HashMap<Key, std::string>> snapshots_;
  // Serializes Get/Put.
  std::mutex mutex_;
};
// Looks up num_keys keys: copies found values into `values` and records the positions
// of absent keys in `missing_indices` (count in `n_missing`). Keys are staged through
// pinned host memory; the stream is synced before the CPU-side table walk.
template<typename Key>
void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 void* values, uint32_t* n_missing, uint32_t* missing_indices) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) {
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    return;
  }
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  // Keys must have landed in host memory before the lookup loop below.
  CHECK_JUST(cuda_stream->Sync());
  *host_n_missing_ = 0;
  for (uint32_t i = 0; i < num_keys; ++i) {
    auto it = store_.find(host_query_keys_[i]);
    if (it != store_.end()) {
      std::memcpy(host_query_values_ + i * value_size_, it->second.data(), value_size_);
    } else {
      host_missing_indices_[*host_n_missing_] = i;
      *host_n_missing_ += 1;
    }
  }
  // NOTE(review): slots of host_query_values_ for missing keys are not rewritten here,
  // so the corresponding ranges of `values` receive stale/unspecified bytes — presumably
  // callers only read positions not reported missing; confirm.
  OF_CUDA_CHECK(hipMemcpyAsync(values, host_query_values_, num_keys * value_size_, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(n_missing, host_n_missing_, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(missing_indices, host_missing_indices_,
                               (*host_n_missing_) * sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
}
// Stores num_keys key/value pairs: stages both arrays into pinned host buffers, syncs
// the stream, then copies each value (value_size_ raw bytes) into the map as a string.
template<typename Key>
void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                 const void* values) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  CHECK_LE(num_keys, max_query_length_);
  if (num_keys == 0) { return; }
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_keys_, keys, key_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  OF_CUDA_CHECK(hipMemcpyAsync(host_query_values_, values, value_size_ * num_keys, hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  // Both staging copies must complete before the CPU reads the buffers.
  CHECK_JUST(cuda_stream->Sync());
  for (uint32_t i = 0; i < num_keys; ++i) {
    // Value bytes are stored verbatim; duplicate keys overwrite earlier entries.
    store_[host_query_keys_[i]] = std::string(
        reinterpret_cast<const char*>(host_query_values_) + i * value_size_, value_size_);
  }
}
// Returns true iff a snapshot with the given name has been saved.
template<typename Key>
bool KeyValueStoreImpl<Key>::SnapshotExists(const std::string& name) {
  auto it = snapshots_.find(name);
  return it != snapshots_.end();
}
// Restores the named snapshot without a post-load hook; delegates to the two-argument
// overload with a null hook.
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  LoadSnapshot(name, nullptr);
}
// Replaces the live store with the named snapshot, then (if provided) invokes Hook with
// an iterator over the restored entries so callers can re-populate derived state.
// NOTE(review): snapshots_[name] default-constructs an empty snapshot if `name` was
// never saved, silently clearing the store — presumably callers check SnapshotExists
// first; confirm.
template<typename Key>
void KeyValueStoreImpl<Key>::LoadSnapshot(const std::string& name,
                                          const std::function<void(KVIterator* iter)>& Hook) {
  CudaCurrentDeviceGuard guard(device_index_);
  store_ = snapshots_[name];
  if (Hook) {
    // The iterator borrows this object's pinned staging buffers.
    IteratorImpl<Key> iterator(&store_, KeySize(), ValueSize(), max_query_length_,
                               host_query_keys_, host_query_values_, host_n_missing_);
    Hook(&iterator);
  }
}
// Saves a full copy of the live store under the given snapshot name, overwriting any
// existing snapshot of the same name.
template<typename Key>
void KeyValueStoreImpl<Key>::SaveSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  snapshots_[name] = store_;
}
}
// namespace
// Factory for the mock key/value store; dispatches on the configured key width
// (8 or 4 bytes). Any other width is unsupported and aborts via UNIMPLEMENTED().
std::unique_ptr<KeyValueStore> NewMockKeyValueStore(const MockKeyValueStoreOptions& options) {
  switch (options.key_size) {
    case sizeof(uint64_t):
      return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint64_t>(options));
    case sizeof(uint32_t):
      return std::unique_ptr<KeyValueStore>(new KeyValueStoreImpl<uint32_t>(options));
    default: {
      UNIMPLEMENTED();
      return nullptr;
    }
  }
}
}
// namespace embedding
}
// namespace oneflow
\ No newline at end of file
oneflow/core/embedding/persistent_table.cpp
View file @
a715222c
...
...
@@ -395,6 +395,7 @@ class PersistentTableImpl : public PersistentTable {
PosixFile
writable_key_file_
;
uint64_t
writable_key_file_chunk_id_
;
PosixFileLockGuard
lock_
;
bool
read_only_
;
};
template
<
typename
Key
,
typename
Engine
>
...
...
@@ -405,14 +406,19 @@ PersistentTableImpl<Key, Engine>::PersistentTableImpl(const PersistentTableOptio
physical_block_size_
(
options
.
physical_block_size
),
logical_block_size_
(
GetLogicalBlockSize
(
options
.
physical_block_size
,
value_size_
)),
blocks_buffer_
(
options
.
physical_block_size
),
writable_key_file_chunk_id_
(
-
1
)
{
writable_key_file_chunk_id_
(
-
1
),
read_only_
(
options
.
read_only
)
{
const
uint64_t
capacity_hint
=
ParseIntegerFromEnv
(
"ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_CAPACITY_HINT"
,
options
.
capacity_hint
);
if
(
capacity_hint
>
0
)
{
row_id_mapping_
.
reserve
(
capacity_hint
);
}
PosixFile
::
RecursiveCreateDirectory
(
options
.
path
,
0755
);
const
std
::
string
lock_filename
=
PosixFile
::
JoinPath
(
options
.
path
,
kLockFileName
);
const
bool
init
=
!
PosixFile
::
FileExists
(
lock_filename
);
lock_
=
PosixFileLockGuard
(
PosixFile
(
lock_filename
,
O_CREAT
|
O_RDWR
,
0644
));
if
(
read_only_
)
{
CHECK
(
!
init
)
<<
"The table must be initialized in read only mode"
;
}
else
{
lock_
=
PosixFileLockGuard
(
PosixFile
(
lock_filename
,
O_CREAT
|
O_RDWR
,
0644
));
}
const
uint64_t
target_chunk_size
=
options
.
target_chunk_size_mb
*
1024
*
1024
;
CHECK_GE
(
target_chunk_size
,
logical_block_size_
);
num_logical_blocks_per_chunk_
=
target_chunk_size
/
logical_block_size_
,
...
...
@@ -442,7 +448,8 @@ PersistentTableImpl<Key, Engine>::PersistentTableImpl(const PersistentTableOptio
for
(
auto
&
chunk
:
chunks
)
{
if
(
value_files_
.
size
()
<=
chunk
.
first
)
{
value_files_
.
resize
(
chunk
.
first
+
1
);
}
CHECK_EQ
(
value_files_
.
at
(
chunk
.
first
).
fd
(),
-
1
);
PosixFile
value_file
(
chunk
.
second
,
O_RDWR
|
O_DIRECT
,
0644
);
const
int
flags
=
read_only_
?
(
O_RDONLY
|
O_DIRECT
)
:
(
O_RDWR
|
O_DIRECT
);
PosixFile
value_file
(
chunk
.
second
,
flags
,
0644
);
value_files_
.
at
(
chunk
.
first
)
=
std
::
move
(
value_file
);
}
if
(
!
value_files_
.
empty
())
{
...
...
@@ -523,6 +530,7 @@ void PersistentTableImpl<Key, Engine>::Get(uint32_t num_keys, const void* keys,
template
<
typename
Key
,
typename
Engine
>
void
PersistentTableImpl
<
Key
,
Engine
>::
PutBlocks
(
uint32_t
num_keys
,
const
void
*
keys
,
const
void
*
blocks
)
{
CHECK
(
!
read_only_
);
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
const
uint32_t
num_blocks
=
RoundUp
(
num_keys
,
num_values_per_block_
)
/
num_values_per_block_
;
const
uint32_t
num_padded_keys
=
num_blocks
*
num_values_per_block_
;
...
...
@@ -579,6 +587,7 @@ void PersistentTableImpl<Key, Engine>::PutBlocks(uint32_t num_keys, const void*
template
<
typename
Key
,
typename
Engine
>
void
PersistentTableImpl
<
Key
,
Engine
>::
Put
(
uint32_t
num_keys
,
const
void
*
keys
,
const
void
*
values
)
{
CHECK
(
!
read_only_
);
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
const
void
*
blocks_ptr
=
nullptr
;
if
(
value_size_
==
logical_block_size_
...
...
@@ -656,6 +665,7 @@ void PersistentTableImpl<Key, Engine>::LoadSnapshotImpl(const std::string& name)
template
<
typename
Key
,
typename
Engine
>
void
PersistentTableImpl
<
Key
,
Engine
>::
SaveSnapshotImpl
(
const
std
::
string
&
name
)
{
CHECK
(
!
read_only_
);
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
PosixFile
::
RecursiveCreateDirectory
(
SnapshotDirPath
(
name
),
0755
);
std
::
ofstream
list_ofs
(
SnapshotListFilePath
(
name
));
...
...
@@ -704,13 +714,11 @@ template<typename Key, typename Engine>
void
PersistentTableImpl
<
Key
,
Engine
>::
LoadSnapshot
(
const
std
::
string
&
name
,
const
std
::
function
<
void
(
Iterator
*
iter
)
>&
Hook
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lock
(
mutex_
);
int
mmap_flags
=
MAP_SHARED
;
if
(
ParseBooleanFromEnv
(
"ONEFLOW_ONE_EMBEDDING_PERSISTENT_TABLE_SNAPSHOT_LOAD_MAP_POPULATE"
,
true
))
{
mmap_flags
|=
MAP_POPULATE
;
}
const
std
::
string
snapshot_base
=
SnapshotDirPath
(
name
);
const
std
::
string
snapshot_list
=
SnapshotListFilePath
(
name
);
row_id_mapping_
.
clear
();
...
...
@@ -723,10 +731,8 @@ void PersistentTableImpl<Key, Engine>::LoadSnapshot(
CHECK_EQ
(
index_file_size
%
sizeof
(
uint64_t
),
0
);
if
(
index_file_size
==
0
)
{
return
;
}
const
size_t
n_entries
=
index_file_size
/
sizeof
(
uint64_t
);
// PosixMappedFile mapped_index(std::move(index_file), index_file_size, PROT_READ);
PosixMappedFile
mapped_index
(
std
::
move
(
index_file
),
index_file_size
,
PROT_READ
,
mmap_flags
);
PosixFile
key_file
(
KeyFilePath
(
chunk_id
),
O_RDONLY
,
0644
);
// PosixMappedFile mapped_key(std::move(key_file), key_file.Size(), PROT_READ);
PosixMappedFile
mapped_key
(
std
::
move
(
key_file
),
key_file
.
Size
(),
PROT_READ
,
mmap_flags
);
const
uint64_t
*
indices
=
static_cast
<
const
uint64_t
*>
(
mapped_index
.
ptr
());
const
Key
*
keys
=
static_cast
<
const
Key
*>
(
mapped_key
.
ptr
());
...
...
@@ -737,7 +743,6 @@ void PersistentTableImpl<Key, Engine>::LoadSnapshot(
}
if
(
Hook
)
{
PosixFile
value_file
(
ValueFilePath
(
chunk_id
),
O_RDONLY
,
0644
);
// PosixMappedFile mapped_value(std::move(value_file), value_file.Size(), PROT_READ);
PosixMappedFile
mapped_value
(
std
::
move
(
value_file
),
value_file
.
Size
(),
PROT_READ
,
mmap_flags
);
ChunkIteratorImpl
<
Key
>
chunk_iterator
(
value_size_
,
logical_block_size_
,
num_values_per_block_
,
num_values_per_chunk_
,
chunk_id
,
n_entries
,
keys
,
...
...
oneflow/core/embedding/persistent_table.h
View file @
a715222c
...
...
@@ -29,6 +29,7 @@ struct PersistentTableOptions {
uint64_t
target_chunk_size_mb
=
4
*
1024
;
uint16_t
physical_block_size
=
4096
;
uint64_t
capacity_hint
=
0
;
bool
read_only
=
false
;
};
class
PersistentTable
{
...
...
oneflow/core/embedding/persistent_table_key_value_store.cu
View file @
a715222c
...
...
@@ -49,14 +49,14 @@ class IteratorImpl : public KVIterator {
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
CHECK_JUST
(
cuda_stream
->
Sync
());
base_iter_
->
Next
(
n_request
,
host_num_buffer_
,
host_keys_buffer_
,
host_values_buffer_
);
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
n_result
,
host_num_buffer_
,
sizeof
(
uint32_t
),
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
n_result
,
host_num_buffer_
,
sizeof
(
uint32_t
),
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
const
uint32_t
num_keys
=
*
host_num_buffer_
;
if
(
num_keys
!=
0
)
{
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
keys
,
host_keys_buffer_
,
num_keys
*
key_size_
,
cuda
MemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
values
,
host_values_buffer_
,
num_keys
*
value_size_
,
cuda
MemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
keys
,
host_keys_buffer_
,
num_keys
*
key_size_
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
values
,
host_values_buffer_
,
num_keys
*
value_size_
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
}
}
...
...
@@ -78,7 +78,7 @@ class KeyValueStoreImpl : public KeyValueStore {
OF_DISALLOW_COPY_AND_MOVE
(
KeyValueStoreImpl
);
explicit
KeyValueStoreImpl
(
const
PersistentTableKeyValueStoreOptions
&
options
)
:
device_index_
(
-
1
),
max_query_length_
(
0
)
{
OF_CUDA_CHECK
(
cuda
GetDevice
(
&
device_index_
));
OF_CUDA_CHECK
(
GPU
(
GetDevice
)
(
&
device_index_
));
key_size_
=
options
.
table_options
.
key_size
;
value_size_
=
options
.
table_options
.
value_size
;
table_
=
NewPersistentTable
(
options
.
table_options
);
...
...
@@ -96,11 +96,11 @@ class KeyValueStoreImpl : public KeyValueStore {
~
KeyValueStoreImpl
()
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
max_query_length_
!=
0
)
{
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_query_keys_
));
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_query_values_
));
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_missing_indices_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_query_keys_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_query_values_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_missing_indices_
));
}
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_n_missing_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_n_missing_
));
}
uint32_t
KeySize
()
const
override
{
return
key_size_
;
}
...
...
@@ -113,9 +113,9 @@ class KeyValueStoreImpl : public KeyValueStore {
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
query_length
<=
max_query_length_
)
{
return
;
}
if
(
max_query_length_
!=
0
)
{
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_query_keys_
));
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_query_values_
));
OF_CUDA_CHECK
(
cuda
FreeHost
(
host_missing_indices_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_query_keys_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_query_values_
));
OF_CUDA_CHECK
(
GPU
(
FreeHost
)
(
host_missing_indices_
));
}
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_query_keys_
),
key_size_
*
query_length
));
...
...
@@ -127,6 +127,7 @@ class KeyValueStoreImpl : public KeyValueStore {
max_query_length_
=
query_length
;
}
using
KeyValueStore
::
Get
;
void
Get
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
void
*
keys
,
void
*
values
,
uint32_t
*
n_missing
,
uint32_t
*
missing_indices
)
override
;
void
Put
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
void
*
keys
,
const
void
*
values
)
override
;
...
...
@@ -157,23 +158,23 @@ void KeyValueStoreImpl<Key>::Get(ep::Stream* stream, uint32_t num_keys, const vo
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
CHECK_LE
(
num_keys
,
max_query_length_
);
if
(
num_keys
==
0
)
{
OF_CUDA_CHECK
(
cuda
MemsetAsync
(
n_missing
,
0
,
sizeof
(
uint32_t
),
OF_CUDA_CHECK
(
GPU
(
MemsetAsync
)
(
n_missing
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
return
;
}
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
CHECK_JUST
(
cuda_stream
->
Sync
());
table_
->
Get
(
num_keys
,
host_query_keys_
,
host_query_values_
,
host_n_missing_
,
host_missing_indices_
);
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
values
,
host_query_values_
,
num_keys
*
value_size_
,
cuda
MemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
n_missing
,
host_n_missing_
,
sizeof
(
uint32_t
),
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
values
,
host_query_values_
,
num_keys
*
value_size_
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
n_missing
,
host_n_missing_
,
sizeof
(
uint32_t
),
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
missing_indices
,
host_missing_indices_
,
(
*
host_n_missing_
)
*
sizeof
(
uint32_t
),
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
missing_indices
,
host_missing_indices_
,
(
*
host_n_missing_
)
*
sizeof
(
uint32_t
),
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
}
...
...
@@ -184,10 +185,10 @@ void KeyValueStoreImpl<Key>::Put(ep::Stream* stream, uint32_t num_keys, const vo
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
CHECK_LE
(
num_keys
,
max_query_length_
);
if
(
num_keys
==
0
)
{
return
;
}
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
cuda
MemcpyDefault
,
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
cuda
MemcpyAsync
(
host_query_values_
,
values
,
value_size_
*
num_keys
,
cuda
MemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
GPU
(
MemcpyAsync
)
(
host_query_values_
,
values
,
value_size_
*
num_keys
,
GPU
(
MemcpyDefault
)
,
cuda_stream
->
cuda_stream
()));
CHECK_JUST
(
cuda_stream
->
Sync
());
table_
->
Put
(
num_keys
,
host_query_keys_
,
host_query_values_
);
}
...
...
oneflow/core/embedding/persistent_table_key_value_store.h
View file @
a715222c
...
...
@@ -23,7 +23,7 @@ namespace oneflow {
namespace
embedding
{
#ifdef
WITH_CUDA
#if
def
ined(
WITH_CUDA
) || defined(WITH_ROCM)
struct
PersistentTableKeyValueStoreOptions
{
PersistentTableOptions
table_options
{};
...
...
@@ -33,16 +33,6 @@ std::unique_ptr<KeyValueStore> NewPersistentTableKeyValueStore(
const
PersistentTableKeyValueStoreOptions
&
options
);
#endif // WITH_CUDA
#ifdef WITH_ROCM
struct
PersistentTableKeyValueStoreOptions
{
PersistentTableOptions
table_options
{};
};
std
::
unique_ptr
<
KeyValueStore
>
NewPersistentTableKeyValueStore
(
const
PersistentTableKeyValueStoreOptions
&
options
);
#endif // WITH_ROCM
}
// namespace embedding
...
...
oneflow/core/embedding/persistent_table_key_value_store.hip.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/embedding/persistent_table_key_value_store.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/embedding/persistent_table.h"
#include <robin_hood.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <dirent.h>
namespace
oneflow
{
namespace
embedding
{
namespace
{
class
IteratorImpl
:
public
KVIterator
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
IteratorImpl
);
IteratorImpl
(
PersistentTable
::
Iterator
*
base_iter
,
uint32_t
key_size
,
uint32_t
value_size
,
uint32_t
max_query_length
,
void
*
host_keys_buffer
,
void
*
host_values_buffer
,
uint32_t
*
host_num_buffer
)
:
base_iter_
(
base_iter
),
key_size_
(
key_size
),
value_size_
(
value_size
),
max_query_length_
(
max_query_length
),
host_keys_buffer_
(
host_keys_buffer
),
host_values_buffer_
(
host_values_buffer
),
host_num_buffer_
(
host_num_buffer
)
{}
~
IteratorImpl
()
override
=
default
;
void
NextN
(
ep
::
Stream
*
stream
,
uint32_t
n_request
,
uint32_t
*
n_result
,
void
*
keys
,
void
*
values
)
override
{
CHECK_LE
(
n_request
,
max_query_length_
);
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
CHECK_JUST
(
cuda_stream
->
Sync
());
base_iter_
->
Next
(
n_request
,
host_num_buffer_
,
host_keys_buffer_
,
host_values_buffer_
);
OF_CUDA_CHECK
(
hipMemcpyAsync
(
n_result
,
host_num_buffer_
,
sizeof
(
uint32_t
),
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
const
uint32_t
num_keys
=
*
host_num_buffer_
;
if
(
num_keys
!=
0
)
{
OF_CUDA_CHECK
(
hipMemcpyAsync
(
keys
,
host_keys_buffer_
,
num_keys
*
key_size_
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
hipMemcpyAsync
(
values
,
host_values_buffer_
,
num_keys
*
value_size_
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
}
}
void
Reset
()
override
{
base_iter_
->
Reset
();
}
private:
PersistentTable
::
Iterator
*
base_iter_
;
uint32_t
key_size_
;
uint32_t
value_size_
;
uint32_t
max_query_length_
;
void
*
host_keys_buffer_
;
void
*
host_values_buffer_
;
uint32_t
*
host_num_buffer_
;
};
template
<
typename
Key
>
class
KeyValueStoreImpl
:
public
KeyValueStore
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
KeyValueStoreImpl
);
explicit
KeyValueStoreImpl
(
const
PersistentTableKeyValueStoreOptions
&
options
)
:
device_index_
(
-
1
),
max_query_length_
(
0
)
{
OF_CUDA_CHECK
(
hipGetDevice
(
&
device_index_
));
key_size_
=
options
.
table_options
.
key_size
;
value_size_
=
options
.
table_options
.
value_size
;
table_
=
NewPersistentTable
(
options
.
table_options
);
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_query_keys_
),
key_size_
*
max_query_length_
));
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_query_values_
),
value_size_
*
max_query_length_
));
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_n_missing_
),
sizeof
(
uint32_t
)));
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_missing_indices_
),
sizeof
(
uint32_t
)
*
max_query_length_
));
}
~
KeyValueStoreImpl
()
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
max_query_length_
!=
0
)
{
OF_CUDA_CHECK
(
hipHostFree
(
host_query_keys_
));
OF_CUDA_CHECK
(
hipHostFree
(
host_query_values_
));
OF_CUDA_CHECK
(
hipHostFree
(
host_missing_indices_
));
}
OF_CUDA_CHECK
(
hipHostFree
(
host_n_missing_
));
}
uint32_t
KeySize
()
const
override
{
return
key_size_
;
}
uint32_t
ValueSize
()
const
override
{
return
value_size_
;
}
uint32_t
MaxQueryLength
()
const
override
{
return
max_query_length_
;
}
void
ReserveQueryLength
(
uint32_t
query_length
)
override
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
query_length
<=
max_query_length_
)
{
return
;
}
if
(
max_query_length_
!=
0
)
{
OF_CUDA_CHECK
(
hipHostFree
(
host_query_keys_
));
OF_CUDA_CHECK
(
hipHostFree
(
host_query_values_
));
OF_CUDA_CHECK
(
hipHostFree
(
host_missing_indices_
));
}
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_query_keys_
),
key_size_
*
query_length
));
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_query_values_
),
value_size_
*
query_length
));
OF_CUDA_CHECK
(
NumaAwareCudaMallocHost
(
device_index_
,
reinterpret_cast
<
void
**>
(
&
host_missing_indices_
),
sizeof
(
uint32_t
)
*
query_length
));
max_query_length_
=
query_length
;
}
void
Get
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
void
*
keys
,
void
*
values
,
uint32_t
*
n_missing
,
uint32_t
*
missing_indices
)
override
;
void
Put
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
void
*
keys
,
const
void
*
values
)
override
;
bool
SnapshotExists
(
const
std
::
string
&
name
)
override
;
void
LoadSnapshot
(
const
std
::
string
&
name
)
override
;
void
LoadSnapshot
(
const
std
::
string
&
name
,
const
std
::
function
<
void
(
KVIterator
*
iter
)
>&
Hook
)
override
;
void
SaveSnapshot
(
const
std
::
string
&
name
)
override
;
private:
int
device_index_
;
uint32_t
max_query_length_
;
uint32_t
key_size_
;
uint32_t
value_size_
;
Key
*
host_query_keys_
{};
uint8_t
*
host_query_values_
{};
uint32_t
*
host_n_missing_
{};
uint32_t
*
host_missing_indices_
{};
std
::
mutex
mutex_
;
std
::
unique_ptr
<
PersistentTable
>
table_
;
};
template
<
typename
Key
>
void
KeyValueStoreImpl
<
Key
>::
Get
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
void
*
keys
,
void
*
values
,
uint32_t
*
n_missing
,
uint32_t
*
missing_indices
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
CHECK_LE
(
num_keys
,
max_query_length_
);
if
(
num_keys
==
0
)
{
OF_CUDA_CHECK
(
hipMemsetAsync
(
n_missing
,
0
,
sizeof
(
uint32_t
),
stream
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()));
return
;
}
OF_CUDA_CHECK
(
hipMemcpyAsync
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
CHECK_JUST
(
cuda_stream
->
Sync
());
table_
->
Get
(
num_keys
,
host_query_keys_
,
host_query_values_
,
host_n_missing_
,
host_missing_indices_
);
OF_CUDA_CHECK
(
hipMemcpyAsync
(
values
,
host_query_values_
,
num_keys
*
value_size_
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
hipMemcpyAsync
(
n_missing
,
host_n_missing_
,
sizeof
(
uint32_t
),
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
hipMemcpyAsync
(
missing_indices
,
host_missing_indices_
,
(
*
host_n_missing_
)
*
sizeof
(
uint32_t
),
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
}
template
<
typename
Key
>
void
KeyValueStoreImpl
<
Key
>::
Put
(
ep
::
Stream
*
stream
,
uint32_t
num_keys
,
const
void
*
keys
,
const
void
*
values
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
auto
cuda_stream
=
stream
->
As
<
ep
::
CudaStream
>
();
CHECK_LE
(
num_keys
,
max_query_length_
);
if
(
num_keys
==
0
)
{
return
;
}
OF_CUDA_CHECK
(
hipMemcpyAsync
(
host_query_keys_
,
keys
,
key_size_
*
num_keys
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
OF_CUDA_CHECK
(
hipMemcpyAsync
(
host_query_values_
,
values
,
value_size_
*
num_keys
,
hipMemcpyDefault
,
cuda_stream
->
cuda_stream
()));
CHECK_JUST
(
cuda_stream
->
Sync
());
table_
->
Put
(
num_keys
,
host_query_keys_
,
host_query_values_
);
}
template
<
typename
Key
>
bool
KeyValueStoreImpl
<
Key
>::
SnapshotExists
(
const
std
::
string
&
name
)
{
return
table_
->
SnapshotExists
(
name
);
}
template
<
typename
Key
>
void
KeyValueStoreImpl
<
Key
>::
LoadSnapshot
(
const
std
::
string
&
name
)
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
LoadSnapshot
(
name
,
nullptr
);
}
template
<
typename
Key
>
void
KeyValueStoreImpl
<
Key
>::
LoadSnapshot
(
const
std
::
string
&
name
,
const
std
::
function
<
void
(
KVIterator
*
iter
)
>&
Hook
)
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
if
(
Hook
)
{
table_
->
LoadSnapshot
(
name
,
[
&
](
PersistentTable
::
Iterator
*
chunk_iterator
)
{
IteratorImpl
iterator
(
chunk_iterator
,
KeySize
(),
ValueSize
(),
max_query_length_
,
host_query_keys_
,
host_query_values_
,
host_n_missing_
);
Hook
(
&
iterator
);
});
}
else
{
table_
->
LoadSnapshot
(
name
);
}
}
template
<
typename
Key
>
void
KeyValueStoreImpl
<
Key
>::
SaveSnapshot
(
const
std
::
string
&
name
)
{
CudaCurrentDeviceGuard
guard
(
device_index_
);
table_
->
SaveSnapshot
(
name
);
}
}
// namespace
std
::
unique_ptr
<
KeyValueStore
>
NewPersistentTableKeyValueStore
(
const
PersistentTableKeyValueStoreOptions
&
options
)
{
if
(
options
.
table_options
.
key_size
==
sizeof
(
uint64_t
))
{
return
std
::
unique_ptr
<
KeyValueStore
>
(
new
KeyValueStoreImpl
<
uint64_t
>
(
options
));
}
else
if
(
options
.
table_options
.
key_size
==
sizeof
(
uint32_t
))
{
return
std
::
unique_ptr
<
KeyValueStore
>
(
new
KeyValueStoreImpl
<
uint32_t
>
(
options
));
}
else
{
UNIMPLEMENTED
();
return
nullptr
;
}
}
}
// namespace embedding
}
// namespace oneflow
\ No newline at end of file
oneflow/core/embedding/posix_file.h
View file @
a715222c
...
...
@@ -141,15 +141,15 @@ class PosixFile final {
class
PosixMappedFile
final
{
public:
PosixMappedFile
()
:
file_
(),
ptr_
(
nullptr
)
{}
//
PosixMappedFile(PosixFile&& file, size_t size, int prot
) : file_(std::move(file)), ptr_(nullptr) {
PosixMappedFile
(
PosixFile
&&
file
,
size_t
size
,
int
prot
,
int
flags
)
:
file_
(
std
::
move
(
file
)),
ptr_
(
nullptr
)
{
PosixMappedFile
(
PosixFile
&&
file
,
size_t
size
,
int
prot
,
int
flags
)
:
file_
(
std
::
move
(
file
)),
ptr_
(
nullptr
)
{
CHECK_NE
(
file_
.
fd
(),
-
1
);
// void* ptr = mmap(nullptr, size, prot, MAP_SHARED, file_.fd(), 0);
void
*
ptr
=
mmap
(
nullptr
,
size
,
prot
,
flags
,
file_
.
fd
(),
0
);
PCHECK
(
ptr
!=
MAP_FAILED
);
ptr_
=
ptr
;
}
PosixMappedFile
(
PosixFile
&&
file
,
size_t
size
,
int
prot
)
:
PosixMappedFile
(
std
::
move
(
file
),
size
,
prot
,
MAP_SHARED
)
{}
PosixMappedFile
(
PosixFile
&&
file
,
size_t
size
,
int
prot
)
:
PosixMappedFile
(
std
::
move
(
file
),
size
,
prot
,
MAP_SHARED
)
{}
PosixMappedFile
(
PosixMappedFile
&&
other
)
noexcept
:
PosixMappedFile
()
{
*
this
=
std
::
move
(
other
);
}
...
...
oneflow/core/ep/common/primitive/batch_matmul.cpp
View file @
a715222c
...
...
@@ -93,7 +93,6 @@ REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BatchMatmulFactory,
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
BatchMatmulFactory
,
BatchMatmulFactoryImpl
<
DeviceType
::
kCUDA
>
);
#endif // WITH_CUDA
#ifdef WITH_ROCM
REGISTER_PRIMITIVE_FACTORY
(
DeviceType
::
kCUDA
,
BatchMatmulFactory
,
BatchMatmulFactoryImpl
<
DeviceType
::
kCUDA
>
);
...
...
oneflow/core/ep/common/primitive/binary_functor.h
View file @
a715222c
...
...
@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/ep/include/primitive/binary_op.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/scalar.h"
#include <cmath>
namespace
oneflow
{
...
...
@@ -124,6 +125,39 @@ struct BinaryFunctor<device, BinaryOp::kGreaterEqual, Src, Dst> {
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
return
static_cast
<
Dst
>
(
src0
>=
src1
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kIsCloseEqualNan
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
atol
(
attr0
.
Value
<
float
>
()),
rtol
(
attr1
.
Value
<
float
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
bool
close
=
src0
==
src1
;
close
|=
(
std
::
isnan
(
src0
)
and
std
::
isnan
(
src1
));
if
(
atol
==
0
and
rtol
==
0
)
return
close
;
Src
allowed_error
=
static_cast
<
Src
>
(
atol
)
+
abs
(
static_cast
<
Src
>
(
rtol
)
*
src1
);
Src
actual_error
=
abs
(
src0
-
src1
);
close
|=
(
std
::
isfinite
(
actual_error
)
and
(
actual_error
<=
allowed_error
));
return
close
;
}
float
atol
,
rtol
;
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kIsClose
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
atol
(
attr0
.
Value
<
float
>
()),
rtol
(
attr1
.
Value
<
float
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
bool
close
=
src0
==
src1
;
if
(
atol
==
0
and
rtol
==
0
)
return
close
;
Src
allowed_error
=
static_cast
<
Src
>
(
atol
)
+
abs
(
static_cast
<
Src
>
(
rtol
)
*
src1
);
Src
actual_error
=
abs
(
src0
-
src1
);
close
|=
(
std
::
isfinite
(
actual_error
)
and
(
actual_error
<=
allowed_error
));
return
close
;
}
float
atol
,
rtol
;
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kLogicalAnd
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
...
...
@@ -147,6 +181,81 @@ struct BinaryFunctor<device, BinaryOp::kLogicalXor, Src, Dst> {
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kFmod
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
return
static_cast
<
Dst
>
(
src0
%
src1
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kFloorDiv
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
return
src0
/
src1
;
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kTruncDiv
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
return
static_cast
<
Dst
>
(
src0
/
src1
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kFloorMod
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
Src
trunc_mod
=
src0
%
src1
;
return
(
trunc_mod
!=
static_cast
<
Src
>
(
0
))
&&
((
src1
<
static_cast
<
Src
>
(
0
))
!=
(
trunc_mod
<
static_cast
<
Src
>
(
0
)))
?
trunc_mod
+
src1
:
trunc_mod
;
}
};
template
<
DeviceType
device
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kFloorMod
,
uint8_t
,
uint8_t
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
uint8_t
operator
()(
uint8_t
src0
,
uint8_t
src1
)
const
{
return
src0
%
src1
;
}
};
template
<
DeviceType
device
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kFloorMod
,
uint32_t
,
uint32_t
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
uint32_t
operator
()(
uint32_t
src0
,
uint32_t
src1
)
const
{
return
src0
%
src1
;
}
};
template
<
DeviceType
device
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kFloorMod
,
uint64_t
,
uint64_t
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
uint64_t
operator
()(
uint64_t
src0
,
uint64_t
src1
)
const
{
return
src0
%
src1
;
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kScalarBasePowerGrad
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
scalar_operand
(
attr0
.
Value
<
Src
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
return
scalar_operand
*
(
pow
(
src0
,
scalar_operand
-
static_cast
<
Src
>
(
1
)))
*
src1
;
}
Src
scalar_operand
;
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kScalarExpPowerGrad
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
scalar_operand
(
attr0
.
Value
<
Src
>
())
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
src0
,
Src
src1
)
const
{
return
(
pow
(
scalar_operand
,
src0
))
*
log
(
scalar_operand
)
*
src1
;
}
Src
scalar_operand
;
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kEluBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
alpha
(
attr0
.
Value
<
double
>
())
{}
...
...
@@ -314,6 +423,226 @@ struct BinaryFunctor<device, BinaryOp::kThresholdBackwardWithDyX, Src, Dst> {
const
Src
threshold
;
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kAbsBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
const
Src
zero
=
static_cast
<
Src
>
(
0.0
);
if
(
x
==
zero
)
{
return
zero
;
}
else
if
(
x
<
zero
)
{
return
-
dy
;
}
else
{
return
dy
;
}
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kAcosBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
-
rsqrt
(
static_cast
<
Src
>
(
1.0
)
-
x
*
x
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kAcoshBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
rsqrt
(
x
*
x
-
static_cast
<
Src
>
(
1.0
));
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kAsinBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
rsqrt
(
static_cast
<
Src
>
(
1.0
)
-
x
*
x
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kAsinhBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
rsqrt
(
static_cast
<
Src
>
(
1.0
)
+
x
*
x
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kAtanBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
const
Src
one
=
static_cast
<
Src
>
(
1.0
);
return
dy
*
(
one
/
(
one
+
x
*
x
));
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kAtanhBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
const
Src
one
=
static_cast
<
Src
>
(
1.0
);
return
dy
*
(
one
/
(
one
-
x
*
x
));
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kCosBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
(
-
sin
(
x
));
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kCoshBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
sinh
(
x
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kErfBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
static_cast
<
Src
>
(
2.0
)
*
rsqrt
(
static_cast
<
Src
>
(
M_PI
))
*
exp
(
-
x
*
x
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kErfcBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
-
static_cast
<
Src
>
(
2.0
)
*
rsqrt
(
static_cast
<
Src
>
(
M_PI
))
*
exp
(
-
x
*
x
);
}
};
template
<
DeviceType
device
,
typename
Src
,
typename
Dst
>
struct
BinaryFunctor
<
device
,
BinaryOp
::
kExpBackwardWithDyX
,
Src
,
Dst
>
{
OF_DEVICE_FUNC
BinaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
{}
OF_DEVICE_FUNC
Dst
operator
()(
Src
dy
,
Src
x
)
const
{
return
dy
*
exp
(
x
);
}
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kExpm1BackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of expm1: d/dx (exp(x) - 1) = exp(x).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src d_expm1 = exp(x);
    return dy * d_expm1;
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLgammaBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of lgamma. Not implemented yet: the true gradient is
  // dy * digamma(x), but there is no digamma primitive here, so this
  // deliberately traps at runtime rather than returning a wrong value.
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    // TODO(chengcheng): return: dy * digamma(x)
    assert(false);
    return 0.0;  // unreachable; keeps the compiler satisfied after assert(false)
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLogBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of log: d/dx log(x) = 1 / x.
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src inv_x = static_cast<Src>(1.0) / x;
    return dy * inv_x;
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLog2BackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of log2: d/dx log2(x) = 1 / (x * ln(2)).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src denom = x * log(static_cast<Src>(2.0));
    return dy * (static_cast<Src>(1.0) / denom);
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLog10BackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of log10: d/dx log10(x) = 1 / (x * ln(10)).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src denom = x * log(static_cast<Src>(10.0));
    return dy * (static_cast<Src>(1.0) / denom);
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLog1pBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of log1p: d/dx log(1 + x) = 1 / (x + 1).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src k_one = static_cast<Src>(1.0);
    return dy * (k_one / (x + k_one));
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kLogSigmoidBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of log-sigmoid: d/dx log(sigmoid(x)) = 1 / (exp(x) + 1).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src k_one = static_cast<Src>(1.0);
    return dy * (k_one / (exp(x) + k_one));
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kReciprocalBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of reciprocal: d/dx (1/x) = -1 / x^2.
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src x_sq = x * x;
    return dy * (-static_cast<Src>(1.0) / x_sq);
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kReciprocalNoNanBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of reciprocal_no_nan: -dy / x^2, except the gradient is
  // defined to be 0 at x == 0 instead of producing inf/nan.
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    // Guard the x == 0 case first (abs(x) <= 0 can only hold at zero).
    if (abs(x) <= static_cast<Src>(0.0)) { return static_cast<Dst>(0.0); }
    const Src x_sq = x * x;
    return dy * (-static_cast<Src>(1.0) / x_sq);
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kRsqrtBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of rsqrt: d/dx x^(-1/2) = -1 / (2 * x^(3/2)).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src x_cubed = x * x * x;
    const Src denom = static_cast<Src>(2.0) * sqrt(x_cubed);
    return dy * (static_cast<Src>(-1.0) / denom);
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSigmoidBackwardWithDyY, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of sigmoid, expressed in terms of the forward output y:
  // d/dx sigmoid(x) = y * (1 - y).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src y) const {
    // Use static_cast<Src>(1.0) (not a bare 1.0 double literal) so the
    // whole expression stays in Src precision, matching every sibling
    // functor and avoiding an implicit float->double promotion on device.
    return dy * (y * (static_cast<Src>(1.0) - y));
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSinBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of sin: d/dx sin(x) = cos(x).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src d_sin = cos(x);
    return dy * d_sin;
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSinhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of sinh: d/dx sinh(x) = cosh(x).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src d_sinh = cosh(x);
    return dy * d_sinh;
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSqrtBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of sqrt: d/dx sqrt(x) = 1 / (2 * sqrt(x)) = 0.5 / sqrt(x).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src half = static_cast<Src>(0.5);
    return dy * half / sqrt(x);
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kSquareBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of square: d/dx x^2 = 2x.
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src two = static_cast<Src>(2.0);
    return dy * two * x;
  }
};
template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kTanBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of tan: d/dx tan(x) = sec^2(x) = 1 / cos^2(x).
  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src c = cos(x);
    return dy * (static_cast<Src>(1.0) / (c * c));
  }
};
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
...
...
Prev
1
…
16
17
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment