OpenDAS / Oneflow / Commits

Commit a715222c ("0.9.1-rocm"), authored Feb 28, 2023 by yuguo
Parent commit: f262efc9

The commit changes 469 files in total (GitLab renders only 469 of 469+ files to preserve performance). The 20 files listed below are the ones shown on this page, accounting for 190 additions and 1891 deletions (+190 / -1891).
Changed files shown on this page (additions / deletions):

  oneflow/core/eager/call_context.h                           +18    -25
  oneflow/core/eager/critical_section_instruction_type.h       +0   -138
  oneflow/core/eager/critical_section_phy_instr_operand.cpp    +0   -117
  oneflow/core/eager/critical_section_phy_instr_operand.h      +0   -295
  oneflow/core/eager/eager_blob_object.cpp                    +54    -30
  oneflow/core/eager/eager_blob_object.h                      +72    -70
  oneflow/core/eager/lazy_job_instruction_type.h               +0   -130
  oneflow/core/eager/lazy_job_phy_instr_operand.cpp            +0    -37
  oneflow/core/eager/lazy_job_phy_instr_operand.h              +0    -79
  oneflow/core/eager/local_dep_object.h                        +5     -1
  oneflow/core/eager/op_call_instruction_type.cpp              +0   -160
  oneflow/core/eager/op_call_instruction_type.h                +0    -46
  oneflow/core/eager/op_call_phy_instr_operand.cpp             +0   -107
  oneflow/core/eager/op_call_phy_instr_operand.h               +0   -121
  oneflow/core/eager/release_tensor_arg_phy_instr_operand.h    +0    -71
  oneflow/core/eager/release_tensor_instruction_type.h         +0   -108
  oneflow/core/embedding/cache.h                               +3     -0
  oneflow/core/embedding/cache_test.cpp                        +1     -1
  oneflow/core/embedding/cached_key_value_store.cu            +37    -29
  oneflow/core/embedding/cached_key_value_store.hip.cpp        +0   -326
oneflow/core/eager/call_context.h (modified, +18 / -25)

@@ -21,17 +21,14 @@ limitations under the License.
 #include "oneflow/core/framework/op_interpreter.h"
 #include "oneflow/core/common/shape_view.h"
 #include "oneflow/core/common/stride.h"
+#include "oneflow/core/common/small_vector.h"

 namespace oneflow {

 namespace one {

 class StatefulLocalOpKernel;
-class ConsistentTensorInferResult;
-
-using EagerBlobObjectList = std::vector<std::shared_ptr<vm::EagerBlobObject>>;
-using EagerBlobObjectListPtr =
-    std::shared_ptr<const std::vector<std::shared_ptr<vm::EagerBlobObject>>>;
+class GlobalTensorInferResult;

 }  // namespace one

@@ -60,10 +57,7 @@ class TmpTensor final : public user_op::Tensor {
   char* mut_tmp_buffer_ptr() { return tmp_buffer_ptr_; }

-  void init_tmp_buffer_ptr(char* ptr) {
-    CHECK_EQ(tmp_buffer_ptr_, nullptr);
-    tmp_buffer_ptr_ = ptr;
-  }
+  void set_tmp_buffer_ptr(char* ptr) { tmp_buffer_ptr_ = ptr; }

  private:
  std::shared_ptr<MemoryCase> mem_case_;

@@ -73,35 +67,34 @@ class TmpTensor final : public user_op::Tensor {
 class CallContext {
  public:
-  CallContext(ComposedAttrMap&& composed_attrs, const one::EagerBlobObjectListPtr& inputs,
-              const one::EagerBlobObjectListPtr& outputs,
-              const std::shared_ptr<const one::ConsistentTensorInferResult>&
-                  consistent_tensor_infer_result,
+  CallContext(ComposedAttrMap&& composed_attrs, vm::EagerBlobObjectList&& inputs,
+              vm::EagerBlobObjectList&& outputs,
+              const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result,
               const one::OpExprInterpContext& op_interp_ctx,
               const std::shared_ptr<MemoryCase>& mem_case)
       : composed_attrs_(std::move(composed_attrs)),
-        inputs_(inputs),
-        outputs_(outputs),
-        consistent_tensor_infer_result_(consistent_tensor_infer_result),
+        inputs_(std::move(inputs)),
+        outputs_(std::move(outputs)),
+        global_tensor_infer_result_(global_tensor_infer_result),
         op_interp_ctx_(op_interp_ctx),
         tmp_tensor_(mem_case) {}

   ~CallContext() = default;

   const ComposedAttrMap& composed_attrs() const { return composed_attrs_; }
-  const one::EagerBlobObjectListPtr& inputs() const { return inputs_; }
-  const one::EagerBlobObjectListPtr& outputs() const { return outputs_; }
-  const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result()
-      const {
-    return consistent_tensor_infer_result_;
-  }
+  const vm::EagerBlobObjectList& inputs() const { return inputs_; }
+  const vm::EagerBlobObjectList& outputs() const { return outputs_; }
+  const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result() const {
+    return global_tensor_infer_result_;
+  }
   const one::OpExprInterpContext& op_interp_ctx() const { return op_interp_ctx_; }
   TmpTensor* mut_tmp_tensor() { return &tmp_tensor_; }

  private:
   const ComposedAttrMap composed_attrs_;
-  const one::EagerBlobObjectListPtr inputs_;
-  const one::EagerBlobObjectListPtr outputs_;
-  const std::shared_ptr<const one::ConsistentTensorInferResult> consistent_tensor_infer_result_;
+  const vm::EagerBlobObjectList inputs_;
+  const vm::EagerBlobObjectList outputs_;
+  const std::shared_ptr<const one::GlobalTensorInferResult> global_tensor_infer_result_;
   const one::OpExprInterpContext op_interp_ctx_;
   TmpTensor tmp_tensor_;
 };
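The constructor change above replaces a shared, immutable pointer to a vector of blob objects with a by-value list that is moved into the call context. The sketch below is a standalone illustration, not OneFlow code: BlobList stands in for vm::EagerBlobObjectList (which in OneFlow is a small_vector with inline capacity, not reproduced here), and the class names are hypothetical.

// Standalone sketch: contrasting the two ownership styles used before and after the change.
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Blob { std::string name; };

using BlobList = std::vector<std::shared_ptr<Blob>>;  // new style: list owned by value
using BlobListPtr = std::shared_ptr<const BlobList>;  // old style: shared, immutable list

class OldCallContext {
 public:
  explicit OldCallContext(const BlobListPtr& inputs) : inputs_(inputs) {}  // refcount bump only
 private:
  BlobListPtr inputs_;
};

class NewCallContext {
 public:
  explicit NewCallContext(BlobList&& inputs) : inputs_(std::move(inputs)) {}  // steals the buffer
 private:
  BlobList inputs_;
};

int main() {
  BlobList blobs{std::make_shared<Blob>(Blob{"x"}), std::make_shared<Blob>(Blob{"y"})};
  OldCallContext old_ctx(std::make_shared<const BlobList>(blobs));  // extra heap allocation
  NewCallContext new_ctx(std::move(blobs));                         // no extra allocation
  (void)old_ctx;
  (void)new_ctx;
  return 0;
}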
oneflow/core/eager/critical_section_instruction_type.h (deleted, 100644 → 0; contents as of parent f262efc9)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_

#include "oneflow/core/vm/critical_section_status_querier.h"
#include "oneflow/core/eager/critical_section_phy_instr_operand.h"
#include "oneflow/core/job/critical_section_instance.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/vm/thread_ctx.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/vm/ref_cnt_instruction_status_querier.h"
#include "oneflow/core/profiler/profiler.h"

namespace oneflow {
namespace vm {

class CriticalSectionBeginInstructionType final : public InstructionType {
 public:
  CriticalSectionBeginInstructionType(const CriticalSectionBeginInstructionType&) = delete;
  CriticalSectionBeginInstructionType(CriticalSectionBeginInstructionType&&) = delete;
  CriticalSectionBeginInstructionType& operator=(const CriticalSectionBeginInstructionType&) = delete;
  CriticalSectionBeginInstructionType& operator=(CriticalSectionBeginInstructionType&&) = delete;
  CriticalSectionBeginInstructionType() = default;
  ~CriticalSectionBeginInstructionType() = default;

  std::string DebugName(const vm::Instruction& instruction) const override {
    return "CriticalSectionBegin";
  }
  Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }
  void Compute(vm::Instruction* instruction) const override {
    OF_PROFILER_RANGE_GUARD("CriticalSectionBegin");
    {
      auto ptr = instruction->phy_instr_operand();
      auto phy_instr_operand = std::dynamic_pointer_cast<CriticalSectionBeginPhyInstrOperand>(ptr);
      CHECK_NOTNULL(phy_instr_operand);
      const auto& critical_section_instance = MakeCriticalSectionInstance(phy_instr_operand);
      const auto& job_name = critical_section_instance->job_name();
      auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();
      for (int i = 0; i < phy_instr_operand->interfaces_op_names().size(); ++i) {
        if (phy_instr_operand->interfaces_valid().at(i)) {
          const std::string& interface_op_name = phy_instr_operand->interfaces_op_names().at(i);
          const auto& buffer_name =
              phy_instr_operand->GetInterfaceBufferName(job_name, interface_op_name);
          buffer_mgr->Get(buffer_name)->Push(critical_section_instance);
        }
      }
      const auto& callback_buffer_name =
          phy_instr_operand->GetInterfaceCriticalSectionCallbackBufferName(job_name);
      buffer_mgr->Get(callback_buffer_name)->Push(critical_section_instance);
      const auto& wait_buffer_name =
          phy_instr_operand->GetInterfaceCriticalSectionWaitBufferName(job_name);
      buffer_mgr->Get(wait_buffer_name)->Push(critical_section_instance);
    }
    {
      auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer();
      auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data);
      status_querier->SetLaunched(std::make_shared<NaiveEventRecord>());
    }
  }

 private:
  class NaiveCriticalSectionInstance final : public CriticalSectionInstance {
   public:
    NaiveCriticalSectionInstance(
        const std::shared_ptr<CriticalSectionBeginPhyInstrOperand>& phy_instr_operand,
        const std::string& job_name)
        : CriticalSectionInstance(), phy_instr_operand_(phy_instr_operand), job_name_(job_name) {}
    ~NaiveCriticalSectionInstance() override = default;

    const std::string& job_name() const override { return job_name_; }

    void AccessBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const override {
      phy_instr_operand_->AccessBlobByOpName(ofblob_ptr, op_name);
    }
    void Finish() const override { phy_instr_operand_->Finish(); }

   private:
    std::shared_ptr<CriticalSectionBeginPhyInstrOperand> phy_instr_operand_;
    std::string job_name_;
  };

  std::shared_ptr<CriticalSectionInstance> MakeCriticalSectionInstance(
      const std::shared_ptr<CriticalSectionBeginPhyInstrOperand>& phy_instr_operand) const {
    phy_instr_operand->FinishInvalidInterfaceEventRecords();
    const auto& job_name = phy_instr_operand->nn_graph()->job_name();
    return std::make_shared<NaiveCriticalSectionInstance>(phy_instr_operand, job_name);
  }
};

class CriticalSectionEndInstructionType final : public InstructionType {
 public:
  CriticalSectionEndInstructionType(const CriticalSectionEndInstructionType&) = delete;
  CriticalSectionEndInstructionType(CriticalSectionEndInstructionType&&) = delete;
  CriticalSectionEndInstructionType& operator=(const CriticalSectionEndInstructionType&) = delete;
  CriticalSectionEndInstructionType& operator=(CriticalSectionEndInstructionType&&) = delete;
  CriticalSectionEndInstructionType() = default;
  ~CriticalSectionEndInstructionType() = default;

  std::string DebugName(const vm::Instruction& instruction) const override {
    return "CriticalSectionEnd";
  }
  Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }
  void Compute(vm::Instruction* instruction) const override {
    const auto* ptr = instruction->phy_instr_operand().get();
    const auto* phy_instr_operand = dynamic_cast<const CriticalSectionEndPhyInstrOperand*>(ptr);
    CHECK_NOTNULL(phy_instr_operand);
    auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer();
    auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data);
    status_querier->SetLaunched(phy_instr_operand->event_record());
  }
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_
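The Compute() above hands a CriticalSectionInstance to the graph runtime by pushing it into several named buffers. The following standalone sketch shows the general push/pull rendezvous idea with a minimal thread-safe named buffer; it does not reproduce OneFlow's BufferMgr, and the buffer name used is hypothetical.

// Minimal sketch of a named producer/consumer buffer (not OneFlow's BufferMgr).
#include <condition_variable>
#include <deque>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>

template<typename T>
class NamedBuffer {
 public:
  void Push(const T& item) {
    { std::lock_guard<std::mutex> lock(mu_); queue_.push_back(item); }
    cv_.notify_one();
  }
  T Pull() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    T item = queue_.front();
    queue_.pop_front();
    return item;
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};

int main() {
  std::map<std::string, NamedBuffer<std::shared_ptr<std::string>>> buffer_mgr;
  auto& wait_buffer = buffer_mgr["job0_critical_section_wait"];  // hypothetical buffer name
  std::thread consumer([&] { std::cout << "got " << *wait_buffer.Pull() << "\n"; });
  wait_buffer.Push(std::make_shared<std::string>("critical_section_instance"));
  consumer.join();
  return 0;
}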
oneflow/core/eager/critical_section_phy_instr_operand.cpp (deleted, 100644 → 0; contents as of parent f262efc9)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/eager/critical_section_phy_instr_operand.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/device/ep_based_event_record.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/vm/stream.h"

namespace oneflow {
namespace vm {

void CriticalSectionBeginPhyInstrOperand::ForEachMirroredObject(
    const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
  for (const auto& eager_blob_object : *eager_blob_objects_) {
    DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object()));
  }
}

void CriticalSectionEndPhyInstrOperand::ForEachMirroredObject(
    const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
  DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object()));
}

void CriticalSectionBeginPhyInstrOperand::ForEachMutMirroredObject(
    const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
  DoEach(vm_stream_->schedule_local_dep_object().get());
}

void CriticalSectionBeginPhyInstrOperand::FinishInvalidInterfaceEventRecords() {
  for (const auto& op_name : interfaces_op_names()) {
    size_t index = CHECK_JUST(MapAt(op_name2interface_index_, op_name));
    if (!interfaces_valid().at(index)) {
      const auto& iter = op_name2end_event_record_->find(op_name);
      CHECK(iter != op_name2end_event_record_->end());
      iter->second->Init(std::make_shared<NaiveEventRecord>());
    }
  }
}

void CriticalSectionBeginPhyInstrOperand::Finish() {
  for (const auto& pair : *op_name2end_event_record_) {
    pair.second->TryInit(std::make_shared<NaiveEventRecord>());
  }
}

void InputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_blob_ptr,
                                                                  const std::string& op_name) {
  int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name));
  CHECK(interfaces_valid().at(i));
  OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
  const auto& eager_blob_object = eager_blob_objects_->at(i);
  {
    size_t header_size = of_blob->mut_blob()->blob_desc().ByteSizeOfBlobHeader();
    CHECK_EQ(header_size, eager_blob_object->shape().NumAxes() * sizeof(int64_t));
    std::memcpy(of_blob->mut_blob()->mut_header_ptr(), eager_blob_object->mut_header_ptr(),
                header_size);
  }
  const auto& end_event_record = op_name2end_event_record_->at(op_name);
  if (eager_blob_object->dptr() == nullptr) {
    end_event_record->Init(std::make_shared<NaiveEventRecord>());
  } else {
    {
      const size_t body_bytes = of_blob->blob().ByteSizeOfBlobBody();
      CHECK_EQ(eager_blob_object->ByteSizeOfBlobBody(), body_bytes);
      AutoMemcpy(of_blob->stream(), of_blob->mut_blob()->mut_dptr(), eager_blob_object->dptr(),
                 body_bytes, of_blob->blob().mem_case(), eager_blob_object->mem_case());
    }
    end_event_record->Init(EpBasedEventRecord::MakeEventRecord(of_blob->stream()));
  }
}

void OutputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_blob_ptr,
                                                                   const std::string& op_name) {
  int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name));
  CHECK(interfaces_valid().at(i));
  OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
  auto& eager_blob_object = eager_blob_objects_->at(i);
  of_blob->blob().shape_view().ToShape(eager_blob_object->mut_shape());
  const auto& end_event_record = op_name2end_event_record_->at(op_name);
  if (eager_blob_object->dptr() == nullptr) {
    end_event_record->Init(std::make_shared<NaiveEventRecord>());
  } else {
    {
      const size_t body_bytes = of_blob->blob().ByteSizeOfBlobBody();
      CHECK_EQ(eager_blob_object->ByteSizeOfBlobBody(), body_bytes);
      AutoMemcpy(of_blob->stream(), eager_blob_object->mut_dptr(), of_blob->blob().dptr(),
                 body_bytes, eager_blob_object->mem_case(), of_blob->blob().mem_case());
    }
    end_event_record->Init(EpBasedEventRecord::MakeEventRecord(of_blob->stream()));
  }
}

void CriticalSectionEndPhyInstrOperand::ForEachMutMirroredObject(
    const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
  DoEach(vm_stream_->schedule_local_dep_object().get());
}

}  // namespace vm
}  // namespace oneflow
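AccessBlobByOpName above copies the blob header (the dimension vector) first and the body afterwards, asserting that the header size equals NumAxes() * sizeof(int64_t). The following is a standalone sketch of that invariant with simplified stand-in types, not OneFlow code.

// Standalone sketch of the header-copy invariant checked in the deleted code above.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  std::vector<int64_t> shape = {2, 3, 4};                     // eager tensor shape
  const size_t header_size = shape.size() * sizeof(int64_t);  // stands in for ByteSizeOfBlobHeader()

  std::vector<int64_t> header_dst(shape.size());              // stands in for the OfBlob header area
  assert(header_size == shape.size() * sizeof(int64_t));      // the CHECK_EQ in the deleted code
  std::memcpy(header_dst.data(), shape.data(), header_size);  // copies dims, not tensor data

  assert(header_dst == shape);
  return 0;
}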
oneflow/core/eager/critical_section_phy_instr_operand.h (deleted, 100644 → 0; contents as of parent f262efc9)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_CRITICAL_SECTION_PHY_INSTR_OPERAND_H_
#define ONEFLOW_CORE_EAGER_CRITICAL_SECTION_PHY_INSTR_OPERAND_H_

#include "oneflow/core/vm/phy_instr_operand.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/device/event_record.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/buffer_manager.h"

namespace oneflow {

namespace one {

using EagerBlobObjectListPtr =
    std::shared_ptr<const std::vector<std::shared_ptr<vm::EagerBlobObject>>>;

}

namespace vm {

class Stream;

class CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand {
 public:
  CriticalSectionBeginPhyInstrOperand(const CriticalSectionBeginPhyInstrOperand&) = delete;
  CriticalSectionBeginPhyInstrOperand(CriticalSectionBeginPhyInstrOperand&&) = delete;
  CriticalSectionBeginPhyInstrOperand& operator=(const CriticalSectionBeginPhyInstrOperand&) = delete;
  CriticalSectionBeginPhyInstrOperand& operator=(CriticalSectionBeginPhyInstrOperand&&) = delete;
  virtual ~CriticalSectionBeginPhyInstrOperand() = default;

  explicit CriticalSectionBeginPhyInstrOperand(
      const std::shared_ptr<NNGraphIf>& nn_graph,
      const one::EagerBlobObjectListPtr& eager_blob_objects,
      const std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&
          op_name2end_event_record,
      vm::Stream* vm_stream)
      : nn_graph_(nn_graph),
        eager_blob_objects_(eager_blob_objects),
        op_name2end_event_record_(op_name2end_event_record),
        vm_stream_(vm_stream) {}

  const std::shared_ptr<NNGraphIf>& nn_graph() const { return nn_graph_; }
  const one::EagerBlobObjectListPtr& eager_blob_objects() const { return eager_blob_objects_; }

  void ForEachMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;

  void ForEachMutMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;

  virtual const std::vector<std::string>& interfaces_op_names() const = 0;
  virtual const std::vector<bool>& interfaces_valid() const = 0;
  virtual std::string GetInterfaceBufferName(const std::string& job_name,
                                             const std::string& op_name) const = 0;
  virtual std::string GetInterfaceCriticalSectionCallbackBufferName(
      const std::string& job_name) const = 0;
  virtual std::string GetInterfaceCriticalSectionWaitBufferName(
      const std::string& job_name) const = 0;
  virtual void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) = 0;
  void FinishInvalidInterfaceEventRecords();
  void Finish();

  void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
    for (const auto& eager_blob_object : *eager_blob_objects_) { DoEach(eager_blob_object.get()); }
  }

 protected:
  std::shared_ptr<NNGraphIf> nn_graph_;
  one::EagerBlobObjectListPtr eager_blob_objects_;
  std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>
      op_name2end_event_record_;
  HashMap<std::string, size_t> op_name2interface_index_;
  vm::Stream* vm_stream_;
};

class InputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeginPhyInstrOperand {
 public:
  InputCriticalSectionBeginPhyInstrOperand(
      const std::shared_ptr<NNGraphIf>& nn_graph,
      const one::EagerBlobObjectListPtr& eager_blob_objects,
      const std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&
          op_name2end_event_record,
      vm::Stream* vm_stream)
      : CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record,
                                            vm_stream),
        input_dependences_(),
        output_dependences_() {
    ForEachConstMirroredObject(SetInserter(&input_dependences_));
    ForEachMutMirroredObject(SetInserter(&output_dependences_));
    ForEachMut2MirroredObject(SetInserter(&output_dependences_));
    CHECK_EQ(nn_graph->inputs_op_names().size(), eager_blob_objects->size());
    CHECK_EQ(nn_graph->inputs_op_names().size(), nn_graph->inputs_valid().size());
    for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) {
      CHECK(op_name2interface_index_.emplace(nn_graph->inputs_op_names().at(i), i).second);
    }
  }

  ~InputCriticalSectionBeginPhyInstrOperand() override = default;

  const DependenceVector& input_dependences() const override { return input_dependences_; }
  const DependenceVector& output_dependences() const override { return output_dependences_; }

  // for inputs
  void ForEachConstMirroredObject(
      const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
    ForEachMirroredObject(DoEach);
  }

  // for outputs
  const std::vector<std::string>& interfaces_op_names() const override {
    return nn_graph_->inputs_op_names();
  }
  const std::vector<bool>& interfaces_valid() const override { return nn_graph_->inputs_valid(); }
  std::string GetInterfaceBufferName(const std::string& job_name,
                                     const std::string& op_name) const override {
    return GetInputBufferName(job_name, op_name);
  }
  std::string GetInterfaceCriticalSectionCallbackBufferName(
      const std::string& job_name) const override {
    return GetInputCriticalSectionCallbackBufferName(job_name);
  }
  std::string GetInterfaceCriticalSectionWaitBufferName(
      const std::string& job_name) const override {
    return GetInputCriticalSectionWaitBufferName(job_name);
  }
  void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) override;

  void ForEachMut2MirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}

 private:
  DependenceVector input_dependences_;
  DependenceVector output_dependences_;
};

class OutputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeginPhyInstrOperand {
 public:
  OutputCriticalSectionBeginPhyInstrOperand(
      const std::shared_ptr<NNGraphIf>& nn_graph,
      const one::EagerBlobObjectListPtr& eager_blob_objects,
      const std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&
          op_name2end_event_record,
      vm::Stream* vm_stream)
      : CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record,
                                            vm_stream),
        input_dependences_(),
        output_dependences_() {
    ForEachConstMirroredObject(SetInserter(&input_dependences_));
    ForEachMutMirroredObject(SetInserter(&output_dependences_));
    ForEachMut2MirroredObject(SetInserter(&output_dependences_));
    CHECK_EQ(nn_graph->outputs_op_names().size(), eager_blob_objects->size());
    CHECK_EQ(nn_graph->outputs_op_names().size(), nn_graph->outputs_valid().size());
    for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) {
      CHECK(op_name2interface_index_.emplace(nn_graph->outputs_op_names().at(i), i).second);
    }
  }

  ~OutputCriticalSectionBeginPhyInstrOperand() override = default;

  const DependenceVector& input_dependences() const override { return input_dependences_; }
  const DependenceVector& output_dependences() const override { return output_dependences_; }

  // for inputs
  void ForEachConstMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}

  // for outputs
  void ForEachMut2MirroredObject(
      const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
    ForEachMirroredObject(DoEach);
  }

  const std::vector<std::string>& interfaces_op_names() const override {
    return nn_graph_->outputs_op_names();
  }
  const std::vector<bool>& interfaces_valid() const override { return nn_graph_->outputs_valid(); }
  std::string GetInterfaceBufferName(const std::string& job_name,
                                     const std::string& op_name) const override {
    return GetOutputBufferName(job_name, op_name);
  }
  std::string GetInterfaceCriticalSectionCallbackBufferName(
      const std::string& job_name) const override {
    return GetOutputCriticalSectionCallbackBufferName(job_name);
  }
  std::string GetInterfaceCriticalSectionWaitBufferName(
      const std::string& job_name) const override {
    return GetOutputCriticalSectionWaitBufferName(job_name);
  }
  void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) override;

 private:
  DependenceVector input_dependences_;
  DependenceVector output_dependences_;
};

class CriticalSectionEndPhyInstrOperand : public PhyInstrOperand {
 public:
  CriticalSectionEndPhyInstrOperand(const std::shared_ptr<EagerBlobObject>& eager_blob_object,
                                    const std::shared_ptr<SharedEventRecord>& event_record,
                                    vm::Stream* vm_stream)
      : eager_blob_object_(eager_blob_object), event_record_(event_record), vm_stream_(vm_stream) {}
  virtual ~CriticalSectionEndPhyInstrOperand() = default;

  const std::shared_ptr<SharedEventRecord>& event_record() const { return event_record_; }

  void ForEachMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;

  void ForEachMutMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;

  void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
    DoEach(eager_blob_object_.get());
  }

 private:
  std::shared_ptr<EagerBlobObject> eager_blob_object_;
  std::shared_ptr<SharedEventRecord> event_record_;
  vm::Stream* vm_stream_;
};

class InputCriticalSecondEndPhyInstrOperand final : public CriticalSectionEndPhyInstrOperand {
 public:
  InputCriticalSecondEndPhyInstrOperand(const std::shared_ptr<EagerBlobObject>& eager_blob_object,
                                        const std::shared_ptr<SharedEventRecord>& event_record,
                                        vm::Stream* vm_stream)
      : CriticalSectionEndPhyInstrOperand(eager_blob_object, event_record, vm_stream),
        input_dependences_(),
        output_dependences_() {
    ForEachConstMirroredObject(SetInserter(&input_dependences_));
    ForEachMutMirroredObject(SetInserter(&output_dependences_));
    ForEachMut2MirroredObject(SetInserter(&output_dependences_));
  }
  ~InputCriticalSecondEndPhyInstrOperand() override = default;

  const DependenceVector& input_dependences() const override { return input_dependences_; }
  const DependenceVector& output_dependences() const override { return output_dependences_; }

  void ForEachConstMirroredObject(
      const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
    ForEachMirroredObject(DoEach);
  }

  void ForEachMut2MirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}

 private:
  DependenceVector input_dependences_;
  DependenceVector output_dependences_;
};

class OutputCriticalSecondEndPhyInstrOperand final : public CriticalSectionEndPhyInstrOperand {
 public:
  OutputCriticalSecondEndPhyInstrOperand(const std::shared_ptr<EagerBlobObject>& eager_blob_object,
                                         const std::shared_ptr<SharedEventRecord>& event_record,
                                         vm::Stream* vm_stream)
      : CriticalSectionEndPhyInstrOperand(eager_blob_object, event_record, vm_stream),
        input_dependences_(),
        output_dependences_() {
    ForEachConstMirroredObject(SetInserter(&input_dependences_));
    ForEachMutMirroredObject(SetInserter(&output_dependences_));
    ForEachMut2MirroredObject(SetInserter(&output_dependences_));
  }
  ~OutputCriticalSecondEndPhyInstrOperand() override = default;

  const DependenceVector& input_dependences() const override { return input_dependences_; }
  const DependenceVector& output_dependences() const override { return output_dependences_; }

  // for inputs
  void ForEachConstMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}

  // for outputs
  void ForEachMut2MirroredObject(
      const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
    ForEachMirroredObject(DoEach);
  }

 private:
  DependenceVector input_dependences_;
  DependenceVector output_dependences_;
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_EAGER_CRITICAL_SECTION_PHY_INSTR_OPERAND_H_
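Each operand above collects its scheduling dependences in its constructor by passing SetInserter callbacks to the ForEach*Object methods. Below is a simplified, self-contained sketch of that idiom; Dep stands in for vm::MirroredObject*, and SetInserter is re-implemented here from scratch rather than taken from OneFlow.

// Standalone sketch of constructor-time dependence collection.
#include <functional>
#include <iostream>
#include <vector>

using Dep = const void*;
using DependenceVector = std::vector<Dep>;

static std::function<void(Dep)> SetInserter(DependenceVector* deps) {
  return [deps](Dep dep) { deps->push_back(dep); };
}

class Operand {
 public:
  Operand(Dep read_obj, Dep written_obj) : read_obj_(read_obj), written_obj_(written_obj) {
    ForEachConstObject(SetInserter(&input_dependences_));  // read-only deps
    ForEachMutObject(SetInserter(&output_dependences_));   // mutated deps
  }
  void ForEachConstObject(const std::function<void(Dep)>& DoEach) const { DoEach(read_obj_); }
  void ForEachMutObject(const std::function<void(Dep)>& DoEach) const { DoEach(written_obj_); }
  const DependenceVector& input_dependences() const { return input_dependences_; }
  const DependenceVector& output_dependences() const { return output_dependences_; }
 private:
  Dep read_obj_;
  Dep written_obj_;
  DependenceVector input_dependences_;
  DependenceVector output_dependences_;
};

int main() {
  int a = 0, b = 0;
  Operand op(&a, &b);
  std::cout << op.input_dependences().size() << " input dep, "
            << op.output_dependences().size() << " output dep\n";
  return 0;
}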
oneflow/core/eager/eager_blob_object.cpp (modified, +54 / -30)

@@ -18,53 +18,77 @@ limitations under the License.
 #include "oneflow/core/framework/to_string.h"
 #include "oneflow/core/framework/shut_down_util.h"
 #include "oneflow/core/common/shape_vec.h"
+#include "oneflow/core/common/tensor_meta.h"

 namespace oneflow {
 namespace vm {

-EagerBlobObject::EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,
-                                 const std::shared_ptr<Shape>& shape,
-                                 const std::shared_ptr<Stride>& stride, DataType data_type,
-                                 const std::shared_ptr<TensorStorage>& tensor_storage,
-                                 const intrusive::shared_ptr<LocalDepObject>& dep_object)
+EagerBlobObject::EagerBlobObject(
+    const std::shared_ptr<MemoryCase>& mem_case,
+    const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,
+    const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,
+    DataType data_type, const std::shared_ptr<TensorStorage>& tensor_storage,
+    const intrusive::shared_ptr<LocalDepObject>& dep_object)
     : is_dynamic_(false),
       mem_case_(mem_case),
       data_type_(data_type),
-      shape_(shape),
-      stride_(stride),
       storage_offset_(0),
       tensor_storage_(tensor_storage),
       mem_ptr_for_allocation_compuation_pipelining_(nullptr),
       inited_mem_ptr_for_allocation_compuation_pipelining_(false),
       is_non_pod_object_placement_newed_(false),
-      is_shape_synced_(true),
       compute_local_dep_object_(dep_object),
-      blob_desc_(shape, stride, data_type) {
-  CHECK(static_cast<bool>(shape));
-  CHECK(static_cast<bool>(stride));
+      static_local_tensor_meta_(static_local_tensor_meta),
+      dynamic_local_tensor_meta_(dynamic_local_tensor_meta) {
   CHECK(static_cast<bool>(tensor_storage));
 }

-Blob* EagerBlobObject::blob() {
-  if (!blob_) {
-    blob_.reset(new Blob(*mem_case_, &blob_desc_, mut_header_ptr(), mut_dptr<char>()));
-  }
-  return blob_.get();
-}
-
-void EagerBlobObject::set_storage_offset(const int64_t offset) { storage_offset_ = offset; }
-
-void EagerBlobObject::TryInitNonPODTypeEagerBlobObjectIfNeed() {
-  if (!IsPODDataType(data_type())) {
-    if (!is_non_pod_object_placement_newed_) {
-      InitNonPODTypeEagerBlobObjectIfNeed(tensor_storage_->non_pod_allocator(), this);
-      is_non_pod_object_placement_newed_ = true;
-    }
-  }
-}
-
-Maybe<void> EagerBlobObject::TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) {
-  vm::Allocator* allocator = device_ctx->mut_allocator();
+// user_op::TensorDesc overrides
+const Shape& EagerBlobObject::shape() const {
+  if (dynamic_local_tensor_meta_) {
+    return dynamic_local_tensor_meta_->shape();
+  } else {
+    return static_local_tensor_meta_->shape();
+  }
+}
+const Stride& EagerBlobObject::stride() const {
+  if (dynamic_local_tensor_meta_) {
+    return dynamic_local_tensor_meta_->stride();
+  } else {
+    return static_local_tensor_meta_->stride();
+  }
+}
+
+void EagerBlobObject::set_shape(const Shape& shape) {
+  CHECK(dynamic_local_tensor_meta_);
+  std::const_pointer_cast<one::MutLocalTensorMeta>(dynamic_local_tensor_meta_)->set_shape(shape);
+}
+void EagerBlobObject::set_stride(const Stride& stride) {
+  CHECK(dynamic_local_tensor_meta_);
+  std::const_pointer_cast<one::MutLocalTensorMeta>(dynamic_local_tensor_meta_)->set_stride(stride);
+}
+
+MutShapeView EagerBlobObject::mut_shape_view() {
+  CHECK(dynamic_local_tensor_meta_);
+  return *const_cast<Shape*>(dynamic_local_tensor_meta_->shape_ptr().get());
+}
+
+std::shared_ptr<const Shape> EagerBlobObject::shape_ptr() const {
+  if (dynamic_local_tensor_meta_) {
+    return dynamic_local_tensor_meta_->shape_ptr();
+  } else {
+    return static_local_tensor_meta_->shape_ptr();
+  }
+}
+std::shared_ptr<const Stride> EagerBlobObject::stride_ptr() const {
+  if (dynamic_local_tensor_meta_) {
+    return dynamic_local_tensor_meta_->stride_ptr();
+  } else {
+    return static_local_tensor_meta_->stride_ptr();
+  }
+}
+
+void EagerBlobObject::set_storage_offset(const int64_t offset) { storage_offset_ = offset; }
+
+Maybe<bool> EagerBlobObject::TryAllocateBlobBodyMemory(vm::Allocator* allocator) {
   size_t required_body_bytes = AlignedByteSizeOfBlobBody();
   if (required_body_bytes == 0) {
     CHECK_ISNULL_OR_RETURN(tensor_storage_->blob_dptr());

@@ -81,10 +105,10 @@ Maybe<void> EagerBlobObject::TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) {
     };
     tensor_storage_->set_blob_dptr(
         std::unique_ptr<char, std::function<void(char*)>>(dptr, Free), required_body_bytes);
-    InitOrCheckMemPtrForAllocationComputationPipelining();
+    InitMemPtrForAllocationComputationPipelining();
+    InitNonPODTypeEagerBlobObjectIfNeed(tensor_storage_->non_pod_allocator(), this);
+    return true;
   }
-  return Maybe<void>::Ok();
+  return false;
 }

 }  // namespace vm
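The new accessors dispatch between a mutable per-object ("dynamic") tensor meta and a shared immutable ("static") one. The following standalone sketch shows only that dispatch pattern with simplified stand-in types; it is not the OneFlow implementation.

// Standalone sketch of the static/dynamic tensor-meta dispatch introduced above.
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

using Shape = std::vector<int64_t>;

struct LocalTensorMeta { Shape shape; };     // static meta: shared, never mutated
struct MutLocalTensorMeta { Shape shape; };  // dynamic meta: owned per object, mutable

class BlobObjectSketch {
 public:
  BlobObjectSketch(std::shared_ptr<const LocalTensorMeta> static_meta,
                   std::shared_ptr<MutLocalTensorMeta> dynamic_meta)
      : static_meta_(std::move(static_meta)), dynamic_meta_(std::move(dynamic_meta)) {}

  const Shape& shape() const {
    if (dynamic_meta_) { return dynamic_meta_->shape; }
    return static_meta_->shape;
  }
  void set_shape(const Shape& shape) {
    assert(dynamic_meta_);  // mirrors CHECK(dynamic_local_tensor_meta_)
    dynamic_meta_->shape = shape;
  }

 private:
  std::shared_ptr<const LocalTensorMeta> static_meta_;
  std::shared_ptr<MutLocalTensorMeta> dynamic_meta_;
};

int main() {
  auto static_meta = std::make_shared<const LocalTensorMeta>(LocalTensorMeta{{2, 3}});
  BlobObjectSketch fixed(static_meta, nullptr);                                    // static only
  BlobObjectSketch dynamic(static_meta, std::make_shared<MutLocalTensorMeta>());   // resizable
  dynamic.set_shape({4, 5});
  assert(fixed.shape() == Shape({2, 3}));
  assert(dynamic.shape() == Shape({4, 5}));
  return 0;
}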
oneflow/core/eager/eager_blob_object.h (modified, +72 / -70)

@@ -18,6 +18,7 @@ limitations under the License.
 #include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/optional.h"
+#include "oneflow/core/common/op_args_reserved_size.h"
 #include "oneflow/core/eager/local_dep_object.h"
 #include "oneflow/core/device/device_context.h"
 #include "oneflow/core/memory/memory_allocator.h"

@@ -25,24 +26,34 @@ limitations under the License.
 #include "oneflow/core/framework/stream.h"
 #include "oneflow/core/framework/tensor_methods.h"
 #include "oneflow/core/framework/user_op_tensor.h"
-#include "oneflow/core/framework/tensor_desc.h"
+#include "oneflow/core/common/tensor_desc.h"
 #include "oneflow/core/register/blob.h"

 namespace oneflow {

+namespace one {
+
+class LocalTensorMeta;
+class MutLocalTensorMeta;
+
+}  // namespace one
+
 namespace vm {

 class TensorStorage {
  public:
   TensorStorage()
-      : non_pod_allocator_(std::make_unique<MemoryAllocator>()),
+      : blob_bytes_(0),
+        non_pod_allocator_(std::make_unique<MemoryAllocator>()),
         producer_stream_(NullOpt),
         last_used_stream_(NullOpt) {}

-  ~TensorStorage() {
+  virtual ~TensorStorage() {
     for (const auto& hook : storage_delete_hooks_) { hook(); }
   }

+  virtual bool is_allocated_in_vm() const = 0;
+
   size_t blob_bytes() const { return blob_bytes_; }

   char* blob_dptr() { return blob_dptr_.get(); }

@@ -84,58 +95,77 @@ class TensorStorage {
   std::vector<std::function<void()>> storage_delete_hooks_;
 };

+class InsideVmTensorStorage : public TensorStorage {
+ public:
+  InsideVmTensorStorage() = default;
+  ~InsideVmTensorStorage() = default;
+
+  bool is_allocated_in_vm() const override { return true; }
+};
+
+class OutsideVmTensorStorage : public TensorStorage {
+ public:
+  OutsideVmTensorStorage() = default;
+  ~OutsideVmTensorStorage() = default;
+
+  bool is_allocated_in_vm() const override { return false; }
+};
+
 class EagerBlobObject final : public user_op::Tensor,
                               public user_op::TensorDesc,
                               public std::enable_shared_from_this<EagerBlobObject> {
  public:
   EagerBlobObject(const EagerBlobObject&) = delete;
   EagerBlobObject(EagerBlobObject&&) = delete;
-  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case, const std::shared_ptr<Shape>& shape,
-                  const std::shared_ptr<Stride>& stride, DataType data_type,
-                  const std::shared_ptr<TensorStorage>& tensor_storage)
-      : EagerBlobObject(mem_case, shape, stride, data_type, tensor_storage,
-                        intrusive::shared_ptr<LocalDepObject>()) {}
-  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case, const std::shared_ptr<Shape>& shape,
-                  const std::shared_ptr<Stride>& stride, DataType data_type,
-                  const std::shared_ptr<TensorStorage>& tensor_storage,
-                  const intrusive::shared_ptr<LocalDepObject>& dep_object);
+  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,
+                  const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,
+                  const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,
+                  DataType data_type, const std::shared_ptr<TensorStorage>& tensor_storage)
+      : EagerBlobObject(mem_case, static_local_tensor_meta, dynamic_local_tensor_meta, data_type,
+                        tensor_storage, intrusive::shared_ptr<LocalDepObject>()) {}
+  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,
+                  const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,
+                  const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,
+                  DataType data_type, const std::shared_ptr<TensorStorage>& tensor_storage,
+                  const intrusive::shared_ptr<LocalDepObject>& dep_object);

   ~EagerBlobObject() { tensor_storage_.reset(); }

+  const std::shared_ptr<const one::MutLocalTensorMeta>& mut_tensor_meta() {
+    return dynamic_local_tensor_meta_;
+  }
+  // Getters
+  const Symbol<one::LocalTensorMeta>& tensor_meta() const { return static_local_tensor_meta_; }
+
   // user_op::TensorDesc overrides
-  const Shape& shape() const override { return *shape_; }
-  Shape* mut_shape() override { return shape_.get(); }
-  const Stride& stride() const override { return *stride_; }
-  Stride* mut_stride() override { return stride_.get(); }
+  const Shape& shape() const override;
+  const Stride& stride() const override;
   DataType data_type() const override { return data_type_; }
   DataType* mut_data_type() override { return &data_type_; }
   bool is_dynamic() const override { return is_dynamic_; }
   bool* mut_is_dynamic() override { return &is_dynamic_; }
+  void set_shape(const Shape& shape) override;
+  void set_stride(const Stride& stride) override;
   void set_data_type(DataType data_type) override { data_type_ = data_type; }
   void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; }

   // user_op::Tensor overrides
-  ShapeView shape_view() const override { return *shape_; }
-  MutShapeView mut_shape_view() override { return *shape_; }
+  ShapeView shape_view() const override { return shape(); }
+  MutShapeView mut_shape_view() override;
   const MemoryCase& mem_case() const override { return *mem_case_; }
   const void* raw_dptr() const override {
-    char* ptr = tensor_storage_->blob_dptr();
-    if (tensor_storage_->blob_bytes() > 0) { CHECK_NOTNULL(ptr); }
-    return ptr + storage_offset_ * GetSizeOfDataType(data_type_);
+    CHECK(inited_mem_ptr_for_allocation_compuation_pipelining_)
+        << "mem_ptr_for_allocation_compuation_pipelining_ not initialized. Please check if there "
+           "are any EagerBlobObjects created outside vm";
+    return mem_ptr_for_allocation_compuation_pipelining_
+           + storage_offset_ * GetSizeOfDataType(data_type_);
   }
   void* mut_raw_dptr() override { return const_cast<void*>(raw_dptr()); }

   void set_storage_offset(const int64_t offset);

-  [[deprecated(
-      "\"Blob\" will be removed in eager. Please avoid to use this method whenever "
-      "possible. Almost all methods of `Blob` are also in `EagerBlobObject`.")]] Blob*
-  blob();
-
-  Maybe<void> TryAllocateBlobBodyMemory(DeviceCtx* device_ctx);
+  // Returns true if allocate successfully.
+  Maybe<bool> TryAllocateBlobBodyMemory(vm::Allocator* allocator);
   Maybe<void> DeallocateBlobDataPtr() {
     tensor_storage_->Release();
-    tensor_storage_.reset(new TensorStorage);
+    tensor_storage_.reset(new InsideVmTensorStorage());
     return Maybe<void>::Ok();
   }

   void RegisterStorageDeleteHook(const std::function<void()>& hook) {

@@ -149,10 +179,6 @@ class EagerBlobObject final : public user_op::Tensor,
   std::shared_ptr<TensorStorage>& tensor_storage() { return tensor_storage_; }

-  bool is_shape_synced() const { return is_shape_synced_; }
-
-  void set_is_shape_synced(bool val) { is_shape_synced_ = val; }
-
   const Optional<Symbol<::oneflow::Stream>>& producer_stream() const {
     return tensor_storage_->producer_stream();
   }

@@ -167,10 +193,10 @@ class EagerBlobObject final : public user_op::Tensor,
     tensor_storage_->set_last_used_stream(last_used_stream);
   }

-  std::shared_ptr<const Shape> shape_ptr() const { return shape_; }
-  std::shared_ptr<const Stride> stride_ptr() const { return stride_; }
+  std::shared_ptr<const Shape> shape_ptr() const;
+  std::shared_ptr<const Stride> stride_ptr() const;

-  size_t ByteSizeOfBlobBody() const { return shape_->elem_cnt() * GetSizeOfDataType(data_type_); }
+  size_t ByteSizeOfBlobBody() const { return shape().elem_cnt() * GetSizeOfDataType(data_type_); }
   size_t AlignedByteSizeOfBlobBody() const {
     return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize);
   }

@@ -179,52 +205,28 @@ class EagerBlobObject final : public user_op::Tensor,
     return RoundUp(ByteSizeOfBlobHeader(), kBlobHeaderAlignSize);
   }

-  const char* header_ptr() const { return reinterpret_cast<const char*>(shape_->dim_vec().data()); }
-  char* mut_header_ptr() { return reinterpret_cast<char*>(shape_->dim_vec().data()); }
-
-  void InitOrCheckMemPtrForAllocationComputationPipelining() {
-    auto* ptr = tensor_storage_->blob_dptr();
-    if (inited_mem_ptr_for_allocation_compuation_pipelining_) {
-      CHECK_EQ(mem_ptr_for_allocation_compuation_pipelining_, ptr);
-    } else {
-      mem_ptr_for_allocation_compuation_pipelining_ = ptr;
-      inited_mem_ptr_for_allocation_compuation_pipelining_ = true;
-    }
-  }
-
-  void TryInitNonPODTypeEagerBlobObjectIfNeed();
+  const char* header_ptr() const { return reinterpret_cast<const char*>(shape().dim_vec().data()); }
+  char* mut_header_ptr() {
+    return reinterpret_cast<char*>(const_cast<int64_t*>(shape().dim_vec().data()));
+  }

  private:
+  void InitMemPtrForAllocationComputationPipelining() {
+    auto* ptr = tensor_storage_->blob_dptr();
+    CHECK(!inited_mem_ptr_for_allocation_compuation_pipelining_)
+        << "mem_ptr_for_allocation_compuation_pipelining_ has been initialized.";
+    mem_ptr_for_allocation_compuation_pipelining_ = ptr;
+    inited_mem_ptr_for_allocation_compuation_pipelining_ = true;
+  }
+
   bool is_dynamic_;
   std::shared_ptr<MemoryCase> mem_case_;
   DataType data_type_;
-  std::shared_ptr<Shape> shape_;
-  std::shared_ptr<Stride> stride_;
   int64_t storage_offset_;
   std::shared_ptr<TensorStorage> tensor_storage_;
   // For allocation-computation pipeline, the value of mem_ptr_for_allocation_compuation_pipelining_
   // are kept even after tensor_storage_.reset().
   char* mem_ptr_for_allocation_compuation_pipelining_;
   bool inited_mem_ptr_for_allocation_compuation_pipelining_;
   bool is_non_pod_object_placement_newed_;
-  std::atomic<bool> is_shape_synced_;
-  bool pin_memory_;
   intrusive::shared_ptr<LocalDepObject> compute_local_dep_object_;
-  BlobDesc blob_desc_;
-  std::unique_ptr<Blob> blob_;
+  // NOTE: Will be removed soon. Avoid to use it whenever possible.
+  Symbol<one::LocalTensorMeta> static_local_tensor_meta_;
+  std::shared_ptr<const one::MutLocalTensorMeta> dynamic_local_tensor_meta_;
 };

+using EagerBlobObjectList = small_vector<std::shared_ptr<vm::EagerBlobObject>, kOpArgsReservedSize>;
+using EagerBlobObjectListPtr = std::shared_ptr<const EagerBlobObjectList>;
+
 }  // namespace vm
 }  // namespace oneflow

 #endif  // ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_
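TensorStorage keeps a list of delete hooks and runs them in its destructor; RegisterStorageDeleteHook is how callers subscribe. Below is a minimal, self-contained sketch of that mechanism, written independently of OneFlow.

// Standalone sketch of the storage-delete-hook mechanism kept by TensorStorage above.
#include <functional>
#include <iostream>
#include <vector>

class TensorStorageSketch {
 public:
  ~TensorStorageSketch() {
    for (const auto& hook : storage_delete_hooks_) { hook(); }
  }
  void RegisterStorageDeleteHook(const std::function<void()>& hook) {
    storage_delete_hooks_.push_back(hook);
  }
 private:
  std::vector<std::function<void()>> storage_delete_hooks_;
};

int main() {
  {
    TensorStorageSketch storage;
    storage.RegisterStorageDeleteHook([] { std::cout << "storage released\n"; });
    std::cout << "storage alive\n";
  }  // hook fires here, when the storage is destroyed
  return 0;
}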
oneflow/core/eager/lazy_job_instruction_type.h (deleted, 100644 → 0; contents as of parent f262efc9)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_

#include "oneflow/core/vm/lazy_job_device_context.h"
#include "oneflow/core/eager/lazy_job_phy_instr_operand.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/of_unused.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/vm/thread_ctx.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/vm/naive_instruction_status_querier.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/kernel/kernel_util.h"

namespace oneflow {

class LazyJobInstance final : public JobInstance {
 public:
  LazyJobInstance(const LazyJobInstance&) = delete;
  LazyJobInstance(LazyJobInstance&&) = delete;
  ~LazyJobInstance() override = default;
  LazyJobInstance(const std::string& job_name, const std::function<void()>& finish_cb)
      : job_name_(job_name), finish_cb_(finish_cb) {}

  std::string job_name() const override { return job_name_; }
  void Finish() const override { finish_cb_(); }

  std::string sole_input_op_name_in_user_job() const override {
    UNIMPLEMENTED();
    return std::string();
  }
  std::string sole_output_op_name_in_user_job() const override {
    UNIMPLEMENTED();
    return std::string();
  }
  void PushBlob(uint64_t ofblob_ptr) const override { UNIMPLEMENTED(); }
  void PullBlob(uint64_t ofblob_ptr) const override { UNIMPLEMENTED(); }

 private:
  const std::string job_name_;
  const std::function<void()> finish_cb_;
};

namespace vm {

class LaunchLazyJobInstructionType final : public InstructionType {  // NOLINT
 public:
  LaunchLazyJobInstructionType(const LaunchLazyJobInstructionType&) = delete;
  LaunchLazyJobInstructionType(LaunchLazyJobInstructionType&&) = delete;
  LaunchLazyJobInstructionType() = default;
  ~LaunchLazyJobInstructionType() = default;

  std::string DebugName(const vm::Instruction&) const override { return "LaunchLazyJob"; }
  Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }
  void Compute(vm::Instruction* instruction) const override {
    const auto& cur_nn_graph = GetCurNNGraph(instruction);
    auto* device_ctx = GetLazyJobDeviceCtx(instruction);
    static thread_local int64_t run_id = 0;
    {
      OF_PROFILER_RANGE_GUARD("WaitUntilQueueEmptyIfFrontNNGraphNotEquals");
      device_ctx->WaitUntilQueueEmptyIfFrontNNGraphNotEquals(cur_nn_graph);
    }
    {
      OF_PROFILER_RANGE_GUARD("Send all buffers to BufferMgr");
      const auto& job_instance = MakeJobInstance(instruction);
      const auto& job_name = job_instance->job_name();
      auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
      buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Push(job_instance);
      buffer_mgr->Get(GetSourceTickBufferName(job_name))->Push(job_instance);
    }
    OF_UNUSED(run_id);  // disable compiler warning.
    OF_PROFILER_RANGE_GUARD("EnqueueNNGraph");
    device_ctx->EnqueueNNGraph(cur_nn_graph);
  }

 private:
  LazyJobDeviceCtx* GetLazyJobDeviceCtx(Instruction* instruction) const {
    auto* stream = instruction->mut_stream();
    auto* device_ctx = dynamic_cast<LazyJobDeviceCtx*>(stream->device_ctx().get());
    CHECK_NOTNULL(device_ctx);
    return device_ctx;
  }

  std::shared_ptr<NNGraphIf> GetCurNNGraph(Instruction* instruction) const {
    const auto* ptr = instruction->phy_instr_operand().get();
    const auto* phy_instr_operand = dynamic_cast<const LaunchLazyJobPhyInstrOperand*>(ptr);
    CHECK_NOTNULL(phy_instr_operand);
    return phy_instr_operand->nn_graph();
  }

  std::shared_ptr<LazyJobInstance> MakeJobInstance(Instruction* instruction) const {
    const auto* ptr = instruction->phy_instr_operand().get();
    const auto* phy_instr_operand = dynamic_cast<const LaunchLazyJobPhyInstrOperand*>(ptr);
    CHECK_NOTNULL(phy_instr_operand);
    const auto& nn_graph = phy_instr_operand->nn_graph();
    const auto& FinishCb = [this, instruction]() {
      auto* device_ctx = GetLazyJobDeviceCtx(instruction);
      device_ctx->DequeueNNGraph();
      auto* status_buffer = instruction->mut_status_buffer();
      NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer())->set_done();
    };
    return std::make_shared<LazyJobInstance>(nn_graph->job_name(), FinishCb);
  }
};

}  // namespace vm
}  // namespace oneflow

#endif  // ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_
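LazyJobInstance above packages the job name together with a finish callback that marks the launching instruction done once the graph run completes. The following is a standalone sketch of that callback pattern with simplified names; it is not the OneFlow class.

// Standalone sketch of the finish-callback pattern used by LazyJobInstance.
#include <functional>
#include <iostream>
#include <string>
#include <utility>

class JobInstanceSketch {
 public:
  JobInstanceSketch(std::string job_name, std::function<void()> finish_cb)
      : job_name_(std::move(job_name)), finish_cb_(std::move(finish_cb)) {}
  const std::string& job_name() const { return job_name_; }
  void Finish() const { finish_cb_(); }
 private:
  std::string job_name_;
  std::function<void()> finish_cb_;
};

int main() {
  bool done = false;
  JobInstanceSketch instance("train_job", [&done] {
    done = true;  // stands in for marking the instruction's status buffer as done
  });
  // ... the lazy job would run on the actor runtime here ...
  instance.Finish();
  std::cout << instance.job_name() << (done ? " done" : " pending") << "\n";
  return 0;
}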
oneflow/core/eager/lazy_job_phy_instr_operand.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/eager/lazy_job_phy_instr_operand.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/vm/virtual_machine.h"
namespace oneflow {
namespace vm {

void LaunchLazyJobPhyInstrOperand::ForEachMutMirroredObject(
    const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
  for (const auto& eager_blob_object : *param_blob_objects_) {
    DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object()));
  }
  DoEach(CHECK_JUST(SingletonMaybe<VirtualMachine>())
             ->FindOrCreateTransportLocalDepObject()
             .Mutable());
}

}  // namespace vm
}  // namespace oneflow
oneflow/core/eager/lazy_job_phy_instr_operand.h deleted 100644 → 0 View file @ f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_LAZY_JOB_PHY_INSTR_OPERAND_H_
#define ONEFLOW_CORE_EAGER_LAZY_JOB_PHY_INSTR_OPERAND_H_
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/eager/local_dep_object.h"
#include "oneflow/core/device/event_record.h"
#include "oneflow/core/eager/critical_section_phy_instr_operand.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/notifier.h"
namespace oneflow {

namespace one {

using EagerBlobObjectListPtr =
    std::shared_ptr<const std::vector<std::shared_ptr<vm::EagerBlobObject>>>;

}

namespace vm {

class LaunchLazyJobPhyInstrOperand final : public PhyInstrOperand {
 public:
  LaunchLazyJobPhyInstrOperand(const LaunchLazyJobPhyInstrOperand&) = delete;
  LaunchLazyJobPhyInstrOperand(LaunchLazyJobPhyInstrOperand&&) = delete;
  ~LaunchLazyJobPhyInstrOperand() override = default;

  LaunchLazyJobPhyInstrOperand(const std::shared_ptr<NNGraphIf>& nn_graph,
                               const one::EagerBlobObjectListPtr& param_blob_objects)
      : nn_graph_(nn_graph),
        param_blob_objects_(param_blob_objects),
        input_dependences_(),
        output_dependences_() {
    ForEachConstMirroredObject(SetInserter(&input_dependences_));
    ForEachMutMirroredObject(SetInserter(&output_dependences_));
    ForEachMut2MirroredObject(SetInserter(&output_dependences_));
    stream_sequential_dependence_ = nullptr;
  }

  const std::shared_ptr<NNGraphIf>& nn_graph() const { return nn_graph_; }

  const DependenceVector& input_dependences() const override { return input_dependences_; }
  const DependenceVector& output_dependences() const override { return output_dependences_; }

  void ForEachConstMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}
  void ForEachMutMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
  void ForEachMut2MirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}

  void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
    for (const auto& eager_blob_object : *param_blob_objects_) { DoEach(eager_blob_object.get()); }
  }

 private:
  std::shared_ptr<NNGraphIf> nn_graph_;
  one::EagerBlobObjectListPtr param_blob_objects_;
  DependenceVector input_dependences_;
  DependenceVector output_dependences_;
};

}  // namespace vm
}  // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_LAZY_JOB_PHY_INSTR_OPERAND_H_
oneflow/core/eager/local_dep_object.h View file @ a715222c
...
@@ -20,12 +20,16 @@ limitations under the License.
 #include "oneflow/core/vm/vm_object.h"
 #include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/symbol.h"
+#include "oneflow/core/common/small_vector.h"
+#include "oneflow/core/common/op_args_reserved_size.h"
 #include "oneflow/core/framework/device.h"

 namespace oneflow {

 // LocalDepObject helps VirtualMachineEngine building instruction edges
-using LocalDepObject = vm::MirroredObject;
+using LocalDepObject = vm::Dependence;
+
+using DependenceVector = small_vector<LocalDepObject*, kOpArgsReservedSize>;

 intrusive::shared_ptr<LocalDepObject> NewLocalDepObject();
...
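The change above swaps the LocalDepObject alias to vm::Dependence and defines DependenceVector on top of small_vector, which keeps up to kOpArgsReservedSize pointers inline so collecting a typical op's dependences avoids heap allocation. A rough usage sketch follows; here small_vector is aliased to std::vector and kOpArgsReservedSize is given an assumed value purely for illustration.

#include <cstddef>
#include <vector>

// Illustration only: stand-ins for oneflow's small_vector and reserved size.
constexpr std::size_t kOpArgsReservedSize = 4;  // assumed value, not the real constant
template<typename T, std::size_t N>
using small_vector = std::vector<T>;  // the real type keeps N elements inline, no heap

struct Dependence {};  // stand-in for vm::Dependence
using LocalDepObject = Dependence;
using DependenceVector = small_vector<LocalDepObject*, kOpArgsReservedSize>;

// Typical use: an operand collects the dep objects its instruction reads/writes.
inline void CollectDeps(DependenceVector* out, LocalDepObject* a, LocalDepObject* b) {
  out->push_back(a);  // for the common case (<= kOpArgsReservedSize deps),
  out->push_back(b);  // a real small_vector never touches the allocator
}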
oneflow/core/eager/op_call_instruction_type.cpp deleted 100644 → 0 View file @ f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/protobuf.h"
#ifdef WITH_ROCM
#include "oneflow/core/ep/rocm/cuda_stream.h"
#else
#include "oneflow/core/ep/cuda/cuda_stream.h"
#endif
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/vm/allocator.h"
#include "oneflow/core/vm/thread_ctx.h"
#include "oneflow/core/eager/op_call_instruction_type.h"
#include "oneflow/core/eager/op_call_phy_instr_operand.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/operator/op_conf_symbol.h"
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/common/cpp_attribute.h"
namespace oneflow {
namespace vm {

struct OpCallInstructionUtil final {
  static inline Maybe<void> Prepare(const vm::Instruction& instruction) {
    auto* operand = GetCallPhyInstrOperand(instruction);
    DeviceCtx* device_ctx = instruction.stream().device_ctx().get();
    JUST(AllocateOutputBlobsMemory(operand, device_ctx));
    if (unlikely(operand->need_temp_storage())) {
      InferTempStorageSize(operand);
      JUST(TryAllocateTempStorage(operand, device_ctx));
      // Since memory block is cached in allocator, it's safe to deallocate tmp buffer before
      // kernel executed.
      DeallocateTempStorage(operand, device_ctx);
    }
    return Maybe<void>::Ok();
  }

  static inline void Compute(const vm::Instruction& instruction) {
    auto* operand = GetCallPhyInstrOperand(instruction);
    DeviceCtx* device_ctx = instruction.stream().device_ctx().get();
    if (!operand->is_all_outputs_pod()) {
      for (const auto& blob_object : *operand->outputs()) {
        blob_object->TryInitNonPODTypeEagerBlobObjectIfNeed();
      }
    }
    user_op::OpKernelState* state = nullptr;
    user_op::OpKernelCache* cache = nullptr;
    if (operand->user_opkernel()->has_state_or_cache()) {
      TryInitOpKernelStateAndCache(operand, device_ctx, &state, &cache);
    }
    OpKernelCompute(operand, device_ctx, state, cache);
  }

  static inline OpCallPhyInstrOperand* GetCallPhyInstrOperand(const vm::Instruction& instruction) {
    auto* operand = CHECK_NOTNULL(instruction.phy_instr_operand().get());
    return CHECK_NOTNULL(dynamic_cast<OpCallPhyInstrOperand*>(operand));
  }

 private:
  static inline void InferTempStorageSize(OpCallPhyInstrOperand* operand) {
    auto* tmp_tensor = operand->mut_call_ctx()->mut_tmp_tensor();
    size_t temp_size =
        operand->opkernel().InferTmpSize(&operand->call_ctx_, operand->user_opkernel());
    tmp_tensor->set_tmp_buffer_size(temp_size);
  }

  static inline void TryInitOpKernelStateAndCache(OpCallPhyInstrOperand* operand,
                                                  DeviceCtx* device_ctx,
                                                  user_op::OpKernelState** state,
                                                  user_op::OpKernelCache** cache) {
    OF_PROFILER_RANGE_GUARD("TryInitOpKernelStateAndCache");
    if (likely(operand->op_interp_ctx().state)) {
      *state = operand->op_interp_ctx().state.get();
      // set state to nullptr so that state initialization in TryInitOpKernelStateAndCache will be
      // skipped.
      state = nullptr;
    }
    operand->mut_opkernel()->TryInitOpKernelStateAndCache(&operand->call_ctx_, device_ctx,
                                                          operand->user_opkernel(), state, cache);
  }

  static inline Maybe<void> AllocateOutputBlobsMemory(OpCallPhyInstrOperand* operand,
                                                      DeviceCtx* device_ctx) {
    OF_PROFILER_RANGE_GUARD("AllocateOutputBlobsMemory");
    for (const auto& blob_object : *operand->outputs()) {
      JUST(blob_object->TryAllocateBlobBodyMemory(device_ctx));
    }
    return Maybe<void>::Ok();
  }

  static inline Maybe<void> TryAllocateTempStorage(OpCallPhyInstrOperand* operand,
                                                   DeviceCtx* device_ctx) {
    OF_PROFILER_RANGE_GUARD("TryAllocateTempStorage");
    auto* tmp_tensor = operand->mut_call_ctx()->mut_tmp_tensor();
    size_t byte_size = tmp_tensor->tmp_buffer_size();
    if (byte_size > 0) {
      char* mem_ptr = nullptr;
      JUST(device_ctx->mut_allocator()->Allocate(&mem_ptr, byte_size));
      tmp_tensor->init_tmp_buffer_ptr(mem_ptr);
    }
    return Maybe<void>::Ok();
  }

  static inline void OpKernelCompute(OpCallPhyInstrOperand* operand, DeviceCtx* device_ctx,
                                     user_op::OpKernelState* state, user_op::OpKernelCache* cache) {
    auto* call_ctx = &operand->call_ctx_;
    auto* user_kernel = operand->user_opkernel();
    operand->mut_opkernel()->Compute(call_ctx, device_ctx, user_kernel, state, cache);
  }

  static inline void DeallocateTempStorage(OpCallPhyInstrOperand* operand, DeviceCtx* device_ctx) {
    OF_PROFILER_RANGE_GUARD("DeallocateTempStorage");
    auto* tmp_tensor = operand->mut_call_ctx()->mut_tmp_tensor();
    device_ctx->mut_allocator()->Deallocate(tmp_tensor->mut_tmp_buffer_ptr(),
                                            tmp_tensor->tmp_buffer_size());
  }
};

Maybe<void> OpCallInstructionType::Prepare(vm::Instruction* instruction) const {
  return OpCallInstructionUtil::Prepare(*instruction);
}

void OpCallInstructionType::Compute(vm::Instruction* instruction) const {
  OpCallInstructionUtil::Compute(*instruction);
}

std::string OpCallInstructionType::DebugName(const vm::Instruction& instruction) const {
  auto* operand = CHECK_NOTNULL(instruction.phy_instr_operand().get());
  return CHECK_NOTNULL(dynamic_cast<OpCallPhyInstrOperand*>(operand))->opkernel().op_type_name()
         + ":OpCall";
}

}  // namespace vm
}  // namespace oneflow
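The Prepare step above allocates the temp buffer and then calls DeallocateTempStorage before the kernel has executed, relying on the premise in the inline comment that freed blocks are cached by the allocator rather than returned to the device. A hypothetical size-bucketed caching allocator that illustrates why the pointer stays valid is sketched below; this is not OneFlow's vm::Allocator.

#include <cstddef>
#include <cstdlib>
#include <unordered_map>
#include <vector>

// Hypothetical sketch: Deallocate() does not hand memory back to the system, it parks
// the block in a per-size free list, so the pointer remains valid backing storage until
// a later Allocate() of the same size recycles it.
class CachingAllocator {
 public:
  void Allocate(char** ptr, std::size_t size) {
    auto& bucket = free_lists_[size];
    if (!bucket.empty()) {
      *ptr = bucket.back();  // reuse a previously "freed" block of the same size
      bucket.pop_back();
      return;
    }
    *ptr = static_cast<char*>(std::malloc(size));  // a real allocator would use device memory
  }
  void Deallocate(char* ptr, std::size_t size) { free_lists_[size].push_back(ptr); }

 private:
  std::unordered_map<std::size_t, std::vector<char*>> free_lists_;
};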
oneflow/core/eager/op_call_instruction_type.h deleted 100644 → 0 View file @ f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/memory/memory_case.pb.h"
namespace oneflow {
namespace vm {

class OpCallInstructionType final : public vm::InstructionType {
 public:
  OpCallInstructionType() = default;
  ~OpCallInstructionType() = default;

  Maybe<void> Prepare(vm::Instruction* instruction) const override;
  void Compute(vm::Instruction* instruction) const override;

  InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; }

  std::string DebugName(const vm::Instruction& instruction) const override;

 protected:
 private:
  Maybe<void> MaybeCompute(vm::Instruction* instruction) const;
};

}  // namespace vm
}  // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_
oneflow/core/eager/op_call_phy_instr_operand.cpp deleted 100644 → 0 View file @ f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/eager/op_call_phy_instr_operand.h"
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h"
#include "oneflow/core/framework/stream_is_comm_net_stream.h"
#include "oneflow/core/vm/stream.h"
namespace oneflow {
namespace vm {

OpCallPhyInstrOperand::OpCallPhyInstrOperand(
    vm::Stream* vm_stream, const std::shared_ptr<one::StatefulOpKernel>& opkernel,
    const one::EagerBlobObjectListPtr& inputs, const one::EagerBlobObjectListPtr& outputs,
    const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result,
    const one::OpExprInterpContext& op_interp_ctx,
    const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode)
    : vm_stream_(vm_stream),
      call_ctx_(ComposedAttrMap(op_interp_ctx.attrs, opkernel->base_attrs()), inputs, outputs,
                consistent_tensor_infer_result, op_interp_ctx, opkernel->mem_case()),
      opkernel_(opkernel),
      user_opkernel_(nullptr),
      infer_tmp_size_fn_(nullptr),
      need_temp_storage_(false),
      dev_vm_dep_object_consume_mode_(dev_vm_dep_object_consume_mode),
      input_dependences_(),
      output_dependences_(),
      is_all_outputs_pod_(false) {
  ForEachConstMirroredObject(SetInserter(&input_dependences_));
  ForEachMutMirroredObject(SetInserter(&output_dependences_));
  ForEachMut2MirroredObject(SetInserter(&output_dependences_));
  InitStreamSequentialDependence();
  for (const auto& blob_object : *outputs) {
    is_all_outputs_pod_ = is_all_outputs_pod_ && IsPODDataType(blob_object->data_type());
  }
}

Maybe<void> OpCallPhyInstrOperand::Init() {
  return mut_opkernel()->ChooseOpKernel(&call_ctx_, &user_opkernel_, &need_temp_storage_);
}

void OpCallPhyInstrOperand::ForEachConstMirroredObject(
    const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
  const auto& input_list = inputs();
  for (int64_t index : opkernel().input_tuple_indexes4const_ibns()) {
    const auto& input = input_list->at(index);
    DoEach(CHECK_JUST(input->compute_local_dep_object()));
  }
}

void OpCallPhyInstrOperand::InitStreamSequentialDependence() {
  auto* device_schedule_dep_object = vm_stream_->schedule_local_dep_object().get();
  if (IsCommNetStream::Visit(vm_stream_->stream_role())) {
    // Sequantialize nccl instructions to avoid deadlock
    stream_sequential_dependence_ = device_schedule_dep_object;
  } else {
    // Sequantialize instructions to avoid explosive memory allocation of source ops
    if (dev_vm_dep_object_consume_mode() == one::DevVmDepObjectConsumeMode::MUTABLE) {
      stream_sequential_dependence_ = device_schedule_dep_object;
    } else if (opkernel().input_tuple_indexes4const_ibns().empty()
               && opkernel().input_tuple_indexes4mut_ibns().empty()) {
      stream_sequential_dependence_ = device_schedule_dep_object;
    }
  }
}

void OpCallPhyInstrOperand::ForEachMutMirroredObject(
    const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
  const auto& opt_transport_dep_object = vm_stream_->transport_local_dep_object();
  if (opt_transport_dep_object.has_value()) { DoEach(CHECK_JUST(opt_transport_dep_object)->get()); }

  const auto& input_list = inputs();
  for (int64_t index : opkernel().input_tuple_indexes4mut_ibns()) {
    const auto& input = input_list->at(index);
    DoEach(CHECK_JUST(input->compute_local_dep_object()));
  }
  const auto& output_list = outputs();
  for (int64_t index : opkernel().output_tuple_indexes4mut_obns()) {
    const auto& output = output_list->at(index);
    DoEach(CHECK_JUST(output->compute_local_dep_object()));
  }
}

void OpCallPhyInstrOperand::ForEachMut2MirroredObject(
    const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
  const auto& output_list = outputs();
  for (int64_t index : opkernel().output_tuple_indexes4mut2_obns()) {
    const auto& output = output_list->at(index);
    DoEach(CHECK_JUST(output->compute_local_dep_object()));
  }
}

}  // namespace vm
}  // namespace oneflow
oneflow/core/eager/op_call_phy_instr_operand.h deleted 100644 → 0 View file @ f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_
#define ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_
#include "oneflow/core/vm/phy_instr_operand.h"
#include "oneflow/core/eager/call_context.h"
#include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h"
#include "oneflow/core/framework/user_op_kernel_registry.h"
namespace oneflow {

namespace user_op {

class OpKernel;

}  // namespace user_op

namespace vm {

class Stream;
struct OpCallInstructionUtil;

class OpCallPhyInstrOperand final : public vm::PhyInstrOperand {
 public:
  OpCallPhyInstrOperand(const OpCallPhyInstrOperand&) = delete;
  OpCallPhyInstrOperand(OpCallPhyInstrOperand&&) = delete;
  ~OpCallPhyInstrOperand() override = default;

  template<typename... Args>
  static Maybe<OpCallPhyInstrOperand> New(Args&&... args) {
    auto* ptr = new OpCallPhyInstrOperand(std::forward<Args>(args)...);
    JUST(ptr->Init());
    return std::shared_ptr<OpCallPhyInstrOperand>(ptr);
  }

  const one::StatefulOpKernel& opkernel() const { return *opkernel_; }
  const one::EagerBlobObjectListPtr& inputs() const { return call_ctx_.inputs(); }
  const one::EagerBlobObjectListPtr& outputs() const { return call_ctx_.outputs(); }
  const AttrMap& attrs() const { return call_ctx_.op_interp_ctx().attrs; }
  const one::OpExprInterpContext& op_interp_ctx() const { return call_ctx_.op_interp_ctx(); }
  const one::DevVmDepObjectConsumeMode& dev_vm_dep_object_consume_mode() const {
    return dev_vm_dep_object_consume_mode_;
  }
  bool is_all_outputs_pod() const { return is_all_outputs_pod_; }

  one::StatefulOpKernel* mut_opkernel() { return opkernel_.get(); }

  template<typename DoEachT>
  Maybe<void> ForEachOutputTensor(const DoEachT& DoEach) {
    for (const auto& output : *outputs()) { JUST(DoEach(output.get())); }
    return Maybe<void>::Ok();
  }

  const DependenceVector& input_dependences() const override { return input_dependences_; }
  const DependenceVector& output_dependences() const override { return output_dependences_; }

  void ForEachConstMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
  void ForEachMutMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
  void ForEachMut2MirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;

  bool need_temp_storage() const { return need_temp_storage_; }
  const user_op::OpKernel* user_opkernel() const { return user_opkernel_; }
  const user_op::InferTmpSizeFn& infer_tmp_size_fn() const { return *infer_tmp_size_fn_; }

  const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result()
      const {
    return call_ctx_.consistent_tensor_infer_result();
  }

  eager::CallContext* mut_call_ctx() { return &call_ctx_; }

  void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
    for (const auto& eager_blob_object : *call_ctx_.inputs()) { DoEach(eager_blob_object.get()); }
  }

 private:
  friend struct OpCallInstructionUtil;
  OpCallPhyInstrOperand(
      vm::Stream* vm_stream, const std::shared_ptr<one::StatefulOpKernel>& opkernel,
      const one::EagerBlobObjectListPtr& inputs, const one::EagerBlobObjectListPtr& outputs,
      const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result,
      const one::OpExprInterpContext& op_interp_ctx,
      const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode);

  Maybe<void> Init();
  void InitStreamSequentialDependence();

  vm::Stream* vm_stream_;
  eager::CallContext call_ctx_;
  std::shared_ptr<one::StatefulOpKernel> opkernel_;
  const user_op::OpKernel* user_opkernel_;
  const user_op::InferTmpSizeFn* infer_tmp_size_fn_;
  bool need_temp_storage_;
  const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode_;
  DependenceVector input_dependences_;
  DependenceVector output_dependences_;
  bool is_all_outputs_pod_;
};

}  // namespace vm
}  // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_
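OpCallPhyInstrOperand in the header above hides its constructor behind a static New() that runs Init() (kernel selection) before the shared_ptr escapes, so callers never see a half-initialized operand. A stripped-down sketch of the same two-phase factory pattern, with std::optional standing in for Maybe<T>, is:

#include <memory>
#include <optional>
#include <utility>

// Sketch only: Maybe<T> is replaced by std::optional<std::shared_ptr<T>>.
class Operand {
 public:
  template<typename... Args>
  static std::optional<std::shared_ptr<Operand>> New(Args&&... args) {
    auto ptr = std::shared_ptr<Operand>(new Operand(std::forward<Args>(args)...));
    if (!ptr->Init()) { return std::nullopt; }  // construction + fallible init in one step
    return ptr;
  }

 private:
  explicit Operand(int config) : config_(config) {}
  bool Init() { return config_ >= 0; }  // stands in for ChooseOpKernel(...)
  int config_;
};

// Usage: auto operand = Operand::New(42); if (!operand) { /* handle the error */ }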
oneflow/core/eager/release_tensor_arg_phy_instr_operand.h deleted 100644 → 0 View file @ f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_RELEASE_TENSOR_ARG_PHY_INSTR_OPERAND_H_
#define ONEFLOW_CORE_EAGER_RELEASE_TENSOR_ARG_PHY_INSTR_OPERAND_H_
#include <functional>
#include <memory>
#include "oneflow/core/intrusive/intrusive.h"
#include "oneflow/core/vm/phy_instr_operand.h"
#include "oneflow/core/eager/local_dep_object.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/vm/stream.h"
namespace oneflow {
namespace vm {

class EagerBlobObject;

class ReleaseTensorArgPhyInstrOperand : public PhyInstrOperand {
 public:
  ReleaseTensorArgPhyInstrOperand(const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
                                  const Optional<vm::Stream*>& stream)
      : eager_blob_object_(eager_blob_object), output_dependences_() {
    output_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object()));
    if (stream.has_value()) {
      stream_sequential_dependence_ = CHECK_JUST(stream)->schedule_local_dep_object().get();
    }
  }
  ~ReleaseTensorArgPhyInstrOperand() override = default;

  const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object() const {
    return eager_blob_object_;
  }

  const DependenceVector& input_dependences() const override {
    static thread_local DependenceVector empty{};
    return empty;
  }
  const DependenceVector& output_dependences() const override { return output_dependences_; }

  void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
    DoEach(eager_blob_object_.get());
  }

 private:
  std::shared_ptr<vm::EagerBlobObject> eager_blob_object_;
  DependenceVector output_dependences_;
};

}  // namespace vm
}  // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_RELEASE_TENSOR_ARG_PHY_INSTR_OPERAND_H_
oneflow/core/eager/release_tensor_instruction_type.h deleted 100644 → 0 View file @ f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/vm/ep_optional_event_record_status_querier.h"
#include "oneflow/core/eager/release_tensor_arg_phy_instr_operand.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/common/stream_role.h"
#include "oneflow/core/common/singleton_ptr.h"
namespace oneflow {
namespace vm {

class ReleaseTensorInstructionType : public vm::InstructionType {
 public:
  ReleaseTensorInstructionType() = default;
  ~ReleaseTensorInstructionType() override = default;

  InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; }

  std::string DebugName(const vm::Instruction& instruction) const override {
    return "ReleaseTensor";
  }
  Maybe<void> Prepare(vm::Instruction* instruction) const override {
    const auto& eager_blob_object = GetEagerBlobObject(*instruction);
    DataType data_type = eager_blob_object->data_type();
    if (IsPODDataType(data_type)) { Release(eager_blob_object); }
    return Maybe<void>::Ok();
  }
  void Compute(vm::Instruction* instruction) const override {
    const auto& eager_blob_object = GetEagerBlobObject(*instruction);
    DataType data_type = eager_blob_object->data_type();
    if (!IsPODDataType(data_type)) { Release(eager_blob_object); }
  }
  void InitInstructionStatus(Instruction* instruction) const override {
    auto* status_buffer = instruction->mut_status_buffer();
    auto* stream = instruction->mut_stream();
    instruction->stream_type().InitInstructionStatus(*stream, status_buffer);
    auto* data_ptr = status_buffer->mut_buffer();
    EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(nullptr);
  }

 private:
  const std::shared_ptr<vm::EagerBlobObject>& GetEagerBlobObject(
      const vm::Instruction& instruction) const {
    const auto& phy_instr_operand = instruction.phy_instr_operand();
    CHECK(static_cast<bool>(phy_instr_operand));
    const auto* ptr =
        dynamic_cast<const vm::ReleaseTensorArgPhyInstrOperand*>(phy_instr_operand.get());
    CHECK_NOTNULL(ptr);
    return ptr->eager_blob_object();
  }
  void Release(const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) const {
    CHECK_JUST(eager_blob_object->DeallocateBlobDataPtr());
  }
};

}  // namespace vm

struct GetReleaseInstructionType : public StreamRoleVisitor<GetReleaseInstructionType> {
  static Maybe<const vm::InstructionType*> VisitCompute(DeviceType device_type) {
    return SingletonPtr<vm::ReleaseTensorInstructionType>();
  }
  static Maybe<const vm::InstructionType*> VisitHost2Device(DeviceType device_type) {
    return SingletonPtr<vm::ReleaseTensorInstructionType>();
  }
  static Maybe<const vm::InstructionType*> VisitDevice2Host(DeviceType device_type) {
    return SingletonPtr<vm::ReleaseTensorInstructionType>();
  }
  static Maybe<const vm::InstructionType*> VisitSyncedLaunchedCommNet(DeviceType device_type) {
    return SingletonPtr<vm::ReleaseTensorInstructionType>();
  }
  static Maybe<const vm::InstructionType*> VisitAsyncedLaunchedCommNet(DeviceType device_type) {
    return SingletonPtr<vm::ReleaseTensorInstructionType>();
  }
  static Maybe<const vm::InstructionType*> VisitBarrier(DeviceType device_type) {
    UNIMPLEMENTED_THEN_RETURN();
  }
  static Maybe<const vm::InstructionType*> VisitCriticalSection(DeviceType device_type) {
    UNIMPLEMENTED_THEN_RETURN();
  }
  static Maybe<const vm::InstructionType*> VisitLazyJobLauncher(DeviceType device_type) {
    UNIMPLEMENTED_THEN_RETURN();
  }
  static Maybe<const vm::InstructionType*> VisitPinnedCompute(DeviceType device_type) {
    return VisitCompute(device_type);
  }
};

}  // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_
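GetReleaseInstructionType above follows the StreamRoleVisitor convention: one static VisitXxx per stream role, with unsupported roles failing via UNIMPLEMENTED_THEN_RETURN(). A condensed sketch of that CRTP dispatch style is given below; the StreamRole enum and the switch-based driver are simplified assumptions, not OneFlow's actual definitions.

#include <stdexcept>
#include <utility>

// Hypothetical role enum; the real StreamRole has more members.
enum class StreamRole { kCompute, kHost2Device, kDevice2Host, kBarrier };

// CRTP driver: the base supplies the dispatch, derived types supply one handler per role.
template<typename Derived>
struct StreamRoleVisitorSketch {
  template<typename... Args>
  static auto Visit(StreamRole role, Args&&... args) {
    switch (role) {
      case StreamRole::kCompute: return Derived::VisitCompute(std::forward<Args>(args)...);
      case StreamRole::kHost2Device: return Derived::VisitHost2Device(std::forward<Args>(args)...);
      case StreamRole::kDevice2Host: return Derived::VisitDevice2Host(std::forward<Args>(args)...);
      case StreamRole::kBarrier: return Derived::VisitBarrier(std::forward<Args>(args)...);
    }
    throw std::logic_error("unknown stream role");
  }
};

// A visitor then only writes the per-role handlers, e.g.:
struct GetName : public StreamRoleVisitorSketch<GetName> {
  static const char* VisitCompute(int) { return "compute"; }
  static const char* VisitHost2Device(int) { return "h2d"; }
  static const char* VisitDevice2Host(int) { return "d2h"; }
  static const char* VisitBarrier(int) { return "barrier"; }
};

// Usage: GetName::Visit(StreamRole::kCompute, 0) returns "compute".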
oneflow/core/embedding/cache.h View file @ a715222c
...
@@ -75,6 +75,9 @@ class Cache {
   }
   virtual void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
                     uint32_t* n_dumped, void* keys, void* values) = 0;
+  virtual void ClearDirtyFlags() = 0;
   virtual void Clear() = 0;
 };
...
oneflow/core/embedding/cache_test.cpp View file @ a715222c
...
@@ -462,7 +462,7 @@ TEST(Cache, FullCache) {
 //   TestCache(cache.get(), line_size);
 // }
-#endif
+#endif  // WITH_ROCM
 }  // namespace
...
oneflow/core/embedding/cached_key_value_store.cu View file @ a715222c
...
@@ -45,22 +45,26 @@ class CacheKeyValueStoreImpl : public KeyValueStore {
   OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl);
   CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache)
       : store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) {
-    OF_CUDA_CHECK(cudaGetDevice(&device_index_));
+    OF_CUDA_CHECK(GPU(GetDevice)(&device_index_));
     CHECK_EQ(store_->KeySize(), cache_->KeySize());
     CHECK_EQ(store_->ValueSize(), cache_->ValueSize());
-    OF_CUDA_CHECK(cudaMalloc(&num_buffer_, sizeof(uint32_t)));
+    OF_CUDA_CHECK(GPU(Malloc)(&num_buffer_, sizeof(uint32_t)));
+#ifdef WITH_ROCM
+    OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&host_num_buffer_), sizeof(uint32_t)));
+#else
     OF_CUDA_CHECK(cudaMallocHost(&host_num_buffer_, sizeof(uint32_t)));
+#endif
     num_elems_per_value_ = store_->ValueSize() / sizeof(Elem);
   }
   ~CacheKeyValueStoreImpl() {
     CudaCurrentDeviceGuard guard(device_index_);
-    OF_CUDA_CHECK(cudaFree(num_buffer_));
-    OF_CUDA_CHECK(cudaFreeHost(host_num_buffer_));
+    OF_CUDA_CHECK(GPU(Free)(num_buffer_));
+    OF_CUDA_CHECK(GPU(FreeHost)(host_num_buffer_));
     if (max_query_length_ != 0) {
-      OF_CUDA_CHECK(cudaFree(keys_buffer_));
-      OF_CUDA_CHECK(cudaFree(values_buffer_));
-      OF_CUDA_CHECK(cudaFree(indices_buffer0_));
-      OF_CUDA_CHECK(cudaFree(indices_buffer1_));
+      OF_CUDA_CHECK(GPU(Free)(keys_buffer_));
+      OF_CUDA_CHECK(GPU(Free)(values_buffer_));
+      OF_CUDA_CHECK(GPU(Free)(indices_buffer0_));
+      OF_CUDA_CHECK(GPU(Free)(indices_buffer1_));
     }
     cache_.reset();
     store_.reset();
...
@@ -76,15 +80,15 @@ class CacheKeyValueStoreImpl : public KeyValueStore {
     if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); }
     if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); }
     if (max_query_length_ != 0) {
-      OF_CUDA_CHECK(cudaFree(keys_buffer_));
-      OF_CUDA_CHECK(cudaFree(values_buffer_));
-      OF_CUDA_CHECK(cudaFree(indices_buffer0_));
-      OF_CUDA_CHECK(cudaFree(indices_buffer1_));
+      OF_CUDA_CHECK(GPU(Free)(keys_buffer_));
+      OF_CUDA_CHECK(GPU(Free)(values_buffer_));
+      OF_CUDA_CHECK(GPU(Free)(indices_buffer0_));
+      OF_CUDA_CHECK(GPU(Free)(indices_buffer1_));
     }
-    OF_CUDA_CHECK(cudaMalloc(&keys_buffer_, query_length * store_->KeySize()));
-    OF_CUDA_CHECK(cudaMalloc(&values_buffer_, query_length * store_->ValueSize()));
-    OF_CUDA_CHECK(cudaMalloc(&indices_buffer0_, query_length * sizeof(uint32_t)));
-    OF_CUDA_CHECK(cudaMalloc(&indices_buffer1_, query_length * sizeof(uint32_t)));
+    OF_CUDA_CHECK(GPU(Malloc)(&keys_buffer_, query_length * store_->KeySize()));
+    OF_CUDA_CHECK(GPU(Malloc)(&values_buffer_, query_length * store_->ValueSize()));
+    OF_CUDA_CHECK(GPU(Malloc)(&indices_buffer0_, query_length * sizeof(uint32_t)));
+    OF_CUDA_CHECK(GPU(Malloc)(&indices_buffer1_, query_length * sizeof(uint32_t)));
     max_query_length_ = query_length;
   }
...
@@ -136,17 +140,17 @@ void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_key
   } else {
     cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_);
   }
-  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault,
-                                cuda_stream->cuda_stream()));
+  OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, num_buffer_, sizeof(uint32_t), GPU(MemcpyDefault),
+                                 cuda_stream->cuda_stream()));
   CHECK_JUST(cuda_stream->Sync());
   const uint32_t num_cache_missing = *host_num_buffer_;
   if (num_cache_missing == 0) {
-    OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t),
-                                  stream->As<ep::CudaStream>()->cuda_stream()));
+    OF_CUDA_CHECK(GPU(MemsetAsync)(n_missing, 0, sizeof(uint32_t),
+                                   stream->As<ep::CudaStream>()->cuda_stream()));
     return;
   }
   store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_);
-  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), cudaMemcpyDefault,
-                                cuda_stream->cuda_stream()));
+  OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, n_missing, sizeof(uint32_t), GPU(MemcpyDefault),
+                                 cuda_stream->cuda_stream()));
   CHECK_JUST(cuda_stream->Sync());
   const uint32_t num_store_missing = *host_num_buffer_;
...
@@ -173,9 +177,12 @@ void CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_key
   std::lock_guard<std::recursive_mutex> lock(mutex_);
   synced_ = false;
   auto cuda_stream = stream->As<ep::CudaStream>();
+  if (cache_->Policy() != CacheOptions::Policy::kFull) {
+    OF_CUDA_CHECK(GPU(MemsetAsync)(num_buffer_, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
+  }
   cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_);
   if (cache_->Policy() == CacheOptions::Policy::kFull) { return; }
-  OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault,
-                                cuda_stream->cuda_stream()));
+  OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, num_buffer_, sizeof(uint32_t), GPU(MemcpyDefault),
+                                 cuda_stream->cuda_stream()));
   CHECK_JUST(cuda_stream->Sync());
   store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
...
@@ -187,6 +194,10 @@ void CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, u
                                                            const void* update, const float* lr,
                                                            float scale) {
   std::lock_guard<std::recursive_mutex> lock(mutex_);
+  if (cache_->Policy() != CacheOptions::Policy::kFull) {
+    OF_CUDA_CHECK(GPU(MemsetAsync)(num_buffer_, 0, sizeof(uint32_t),
+                                   stream->As<ep::CudaStream>()->cuda_stream()));
+  }
   if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) {
     UNIMPLEMENTED();
   }
...
@@ -221,17 +232,13 @@ void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(
       auto* cuda_stream = stream->As<ep::CudaStream>();
       while (true) {
         iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_);
-        OF_CUDA_CHECK(cudaDeviceSynchronize());
-        OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
-                                      cudaMemcpyDefault, cuda_stream->cuda_stream()));
+        OF_CUDA_CHECK(GPU(DeviceSynchronize)());
+        OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, num_buffer_, sizeof(uint32_t),
+                                       GPU(MemcpyDefault), cuda_stream->cuda_stream()));
         CHECK_JUST(stream->Sync());
         if (*host_num_buffer_ == 0) { return; }
         cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr,
                     nullptr);
-        OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
-                                      cudaMemcpyDefault, cuda_stream->cuda_stream()));
-        CHECK_JUST(stream->Sync());
-        CHECK_EQ(*host_num_buffer_, 0);
       }
     }
     if (Hook) {
...
@@ -267,13 +274,14 @@ void CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() {
     cache_->Dump(stream, start_key_index,
                  std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_,
                  keys_buffer_, values_buffer_);
-    OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
-                                  cudaMemcpyDefault, cuda_stream->cuda_stream()));
+    OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, num_buffer_, sizeof(uint32_t),
+                                   GPU(MemcpyDefault), cuda_stream->cuda_stream()));
     CHECK_JUST(stream->Sync());
     if (*host_num_buffer_ == 0) { continue; }
     store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
     CHECK_JUST(stream->Sync());
   }
+  cache_->ClearDirtyFlags();
   device->DestroyStream(stream);
   synced_ = true;
 }
...
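Every runtime call in the diff above is rewritten to go through a GPU(...) macro so the same .cu source can build against either CUDA or ROCm. The macro's definition is not part of this diff; one plausible shape for such a token-pasting wrapper (an assumption, shown only for orientation) is:

// Illustrative guess at the dispatch macro; the actual definition lives elsewhere in the tree.
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#define GPU(call) hip##call   // GPU(Malloc) -> hipMalloc, GPU(MemcpyDefault) -> hipMemcpyDefault
#else
#include <cuda_runtime.h>
#define GPU(call) cuda##call  // GPU(Malloc) -> cudaMalloc, GPU(MemcpyDefault) -> cudaMemcpyDefault
#endif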
oneflow/core/embedding/cached_key_value_store.hip.cpp deleted 100644 → 0 View file @ f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/cached_key_value_store.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
namespace oneflow {

namespace embedding {

namespace {

template<typename Key, typename Elem>
__global__ void PostStoreGetKernel(uint32_t num_cache_missing, uint32_t num_store_missing,
                                   uint32_t num_elems_per_value,
                                   const uint32_t* cache_missing_indices,
                                   const uint32_t* store_missing_indices, const Elem* store_values,
                                   Elem* values, uint32_t* missing_indices) {
  const uint32_t num_cache_missing_elem = num_cache_missing * num_elems_per_value;
  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_cache_missing_elem) {
    const uint32_t value_index = i / num_elems_per_value;
    const uint32_t elem_index = i - value_index * num_elems_per_value;
    values[cache_missing_indices[value_index] * num_elems_per_value + elem_index] = store_values[i];
  }
  CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_store_missing) {
    missing_indices[i] = cache_missing_indices[store_missing_indices[i]];
  }
}

template<typename Key, typename Elem>
class CacheKeyValueStoreImpl : public KeyValueStore {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl);
  CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache)
      : store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) {
    OF_CUDA_CHECK(hipGetDevice(&device_index_));
    CHECK_EQ(store_->KeySize(), cache_->KeySize());
    CHECK_EQ(store_->ValueSize(), cache_->ValueSize());
    OF_CUDA_CHECK(hipMalloc(&num_buffer_, sizeof(uint32_t)));
    OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void**>(&host_num_buffer_), sizeof(uint32_t)));
    num_elems_per_value_ = store_->ValueSize() / sizeof(Elem);
  }
  ~CacheKeyValueStoreImpl() {
    CudaCurrentDeviceGuard guard(device_index_);
    OF_CUDA_CHECK(hipFree(num_buffer_));
    OF_CUDA_CHECK(hipHostFree(host_num_buffer_));
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipFree(keys_buffer_));
      OF_CUDA_CHECK(hipFree(values_buffer_));
      OF_CUDA_CHECK(hipFree(indices_buffer0_));
      OF_CUDA_CHECK(hipFree(indices_buffer1_));
    }
    cache_.reset();
    store_.reset();
  }

  uint32_t KeySize() const override { return store_->KeySize(); }
  uint32_t ValueSize() const override { return store_->ValueSize(); }
  uint32_t MaxQueryLength() const override { return max_query_length_; }

  void ReserveQueryLength(uint32_t query_length) override {
    CudaCurrentDeviceGuard guard(device_index_);
    if (query_length <= max_query_length_) { return; }
    if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); }
    if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); }
    if (max_query_length_ != 0) {
      OF_CUDA_CHECK(hipFree(keys_buffer_));
      OF_CUDA_CHECK(hipFree(values_buffer_));
      OF_CUDA_CHECK(hipFree(indices_buffer0_));
      OF_CUDA_CHECK(hipFree(indices_buffer1_));
    }
    OF_CUDA_CHECK(hipMalloc(&keys_buffer_, query_length * store_->KeySize()));
    OF_CUDA_CHECK(hipMalloc(&values_buffer_, query_length * store_->ValueSize()));
    OF_CUDA_CHECK(hipMalloc(&indices_buffer0_, query_length * sizeof(uint32_t)));
    OF_CUDA_CHECK(hipMalloc(&indices_buffer1_, query_length * sizeof(uint32_t)));
    max_query_length_ = query_length;
  }

  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
           uint32_t* n_missing, uint32_t* missing_indices) override;
  void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
           uint8_t* mask) override;
  void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
  void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
                          const void* update, const float* lr, float scale) override;
  bool IsFusionSupported() override {
    return cache_->Policy() == CacheOptions::Policy::kFull
           && cache_->ValueType() == DataType::kFloat;
  }
  bool SnapshotExists(const std::string& name) override;
  void LoadSnapshot(const std::string& name) override;
  void SaveSnapshot(const std::string& name) override;
  void LoadSnapshot(const std::string& name,
                    const std::function<void(KVIterator* iter)>& Hook) override;

 private:
  void SyncCacheToStore();

  std::unique_ptr<KeyValueStore> store_;
  std::unique_ptr<Cache> cache_;
  uint32_t* num_buffer_{};
  uint32_t* host_num_buffer_{};
  Key* keys_buffer_{};
  Elem* values_buffer_{};
  uint32_t* indices_buffer0_{};
  uint32_t* indices_buffer1_{};
  int device_index_{};
  uint32_t max_query_length_;
  uint32_t num_elems_per_value_{};
  std::recursive_mutex mutex_;
  bool synced_;
};

template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                            void* values, uint32_t* n_missing,
                                            uint32_t* missing_indices) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  auto cuda_stream = stream->As<ep::CudaStream>();
  if (cache_->Policy() == CacheOptions::Policy::kFull) {
    cache_->Get(stream, num_keys, keys, values, n_missing, keys_buffer_, missing_indices);
    return;
  } else {
    cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_);
  }
  OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  const uint32_t num_cache_missing = *host_num_buffer_;
  if (num_cache_missing == 0) {
    OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
                                 stream->As<ep::CudaStream>()->cuda_stream()));
    return;
  }
  store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_);
  OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  const uint32_t num_store_missing = *host_num_buffer_;
  RUN_CUDA_KERNEL((PostStoreGetKernel<Key, Elem>), stream, num_cache_missing * num_elems_per_value_,
                  num_cache_missing, num_store_missing, num_elems_per_value_, indices_buffer0_,
                  indices_buffer1_, values_buffer_, static_cast<Elem*>(values), missing_indices);
}

template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                            void* values, uint8_t* mask) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  if (cache_->Policy() == CacheOptions::Policy::kFull) {
    cache_->Get(stream, num_keys, keys, values, mask);
    return;
  } else {
    UNIMPLEMENTED();
  }
}

template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
                                            const void* values) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  synced_ = false;
  auto cuda_stream = stream->As<ep::CudaStream>();
  cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_);
  if (cache_->Policy() == CacheOptions::Policy::kFull) { return; }
  OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                               cuda_stream->cuda_stream()));
  CHECK_JUST(cuda_stream->Sync());
  store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
}

template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, uint32_t num_keys,
                                                           const void* keys, const void* values,
                                                           const void* update, const float* lr,
                                                           float scale) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) {
    UNIMPLEMENTED();
  }
  synced_ = false;
  cache_->FusedHalfUpdatePut(stream, num_keys, keys, values, update, lr, scale, num_buffer_,
                             keys_buffer_, values_buffer_);
}

template<typename Key, typename Elem>
bool CacheKeyValueStoreImpl<Key, Elem>::SnapshotExists(const std::string& name) {
  return store_->SnapshotExists(name);
}

template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(const std::string& name) {
  LoadSnapshot(name, nullptr);
}

template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(
    const std::string& name, const std::function<void(KVIterator* iter)>& Hook) {
  CudaCurrentDeviceGuard guard(device_index_);
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  CHECK_GT(max_query_length_, 0);
  cache_->Clear();
  auto device =
      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
  CHECK(device);
  auto* stream = device->CreateStream();
  store_->LoadSnapshot(name, [&](KVIterator* iter) {
    if (cache_->Policy() == CacheOptions::Policy::kFull) {
      auto* cuda_stream = stream->As<ep::CudaStream>();
      while (true) {
        iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_);
        OF_CUDA_CHECK(hipDeviceSynchronize());
        OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
                                     hipMemcpyDefault, cuda_stream->cuda_stream()));
        CHECK_JUST(stream->Sync());
        if (*host_num_buffer_ == 0) { return; }
        cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr,
                    nullptr);
        OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
                                     hipMemcpyDefault, cuda_stream->cuda_stream()));
        CHECK_JUST(stream->Sync());
        CHECK_EQ(*host_num_buffer_, 0);
      }
    }
    if (Hook) {
      iter->Reset();
      Hook(iter);
    }
  });
  device->DestroyStream(stream);
  store_->LoadSnapshot(name);
}

template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SaveSnapshot(const std::string& name) {
  CudaCurrentDeviceGuard guard(device_index_);
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  SyncCacheToStore();
  store_->SaveSnapshot(name);
}

template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() {
  if (synced_) { return; }
  CudaCurrentDeviceGuard guard(device_index_);
  auto device =
      Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
  CHECK(device);
  auto* stream = device->CreateStream();
  auto* cuda_stream = stream->As<ep::CudaStream>();
  const uint64_t dump_capacity = cache_->DumpCapacity();
  CHECK_GT(max_query_length_, 0);
  for (uint64_t start_key_index = 0; start_key_index < dump_capacity;
       start_key_index += max_query_length_) {
    cache_->Dump(stream, start_key_index,
                 std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_,
                 keys_buffer_, values_buffer_);
    OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
                                 cuda_stream->cuda_stream()));
    CHECK_JUST(stream->Sync());
    if (*host_num_buffer_ == 0) { continue; }
    store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
    CHECK_JUST(stream->Sync());
  }
  device->DestroyStream(stream);
  synced_ = true;
}

template<typename Key>
std::unique_ptr<KeyValueStore> DispatchElemType(std::unique_ptr<KeyValueStore>&& store,
                                                std::unique_ptr<Cache>&& cache) {
  const uint32_t value_size = store->ValueSize();
  if (value_size % sizeof(uint4) == 0) {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint4>(std::move(store), std::move(cache)));
  } else if (value_size % sizeof(uint64_t) == 0) {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint64_t>(std::move(store), std::move(cache)));
  } else if (value_size % sizeof(uint32_t) == 0) {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint32_t>(std::move(store), std::move(cache)));
  } else if (value_size % sizeof(uint16_t) == 0) {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint16_t>(std::move(store), std::move(cache)));
  } else {
    return std::unique_ptr<KeyValueStore>(
        new CacheKeyValueStoreImpl<Key, uint8_t>(std::move(store), std::move(cache)));
  }
}

std::unique_ptr<KeyValueStore> DispatchKeyType(std::unique_ptr<KeyValueStore>&& store,
                                               std::unique_ptr<Cache>&& cache) {
  const uint32_t key_size = store->KeySize();
  if (key_size == 4) {
    return DispatchElemType<uint32_t>(std::move(store), std::move(cache));
  } else if (key_size == 8) {
    return DispatchElemType<uint64_t>(std::move(store), std::move(cache));
  } else {
    UNIMPLEMENTED();
    return nullptr;
  }
}

}  // namespace

std::unique_ptr<KeyValueStore> NewCachedKeyValueStore(std::unique_ptr<KeyValueStore>&& store,
                                                      std::unique_ptr<Cache>&& cache) {
  return DispatchKeyType(std::move(store), std::move(cache));
}

}  // namespace embedding
}  // namespace oneflow
\ No newline at end of file
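DispatchElemType in the deleted file above picks the widest unsigned element type that evenly divides the value size (uint4, then uint64_t, uint32_t, uint16_t, uint8_t), which lets kernels such as PostStoreGetKernel copy each embedding row with the widest possible loads. A small host-side sketch of that selection rule:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Returns the element width in bytes that the cascade in DispatchElemType would choose
// for a given value size (16 = sizeof(uint4)).
inline std::size_t ChosenElemWidth(std::uint32_t value_size) {
  if (value_size % 16 == 0) { return 16; }  // uint4: one 128-bit access per element
  if (value_size % 8 == 0) { return 8; }    // uint64_t
  if (value_size % 4 == 0) { return 4; }    // uint32_t
  if (value_size % 2 == 0) { return 2; }    // uint16_t
  return 1;                                 // uint8_t fallback
}

int main() {
  // e.g. a 128-dim float embedding row: 128 * 4 = 512 bytes, so 16-byte elements are used.
  std::printf("%zu\n", ChosenElemWidth(512));  // prints 16
  return 0;
}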