Commit a715222c authored by yuguo

0.9.1-rocm

parent f262efc9
@@ -21,17 +21,14 @@ limitations under the License.
 #include "oneflow/core/framework/op_interpreter.h"
 #include "oneflow/core/common/shape_view.h"
 #include "oneflow/core/common/stride.h"
+#include "oneflow/core/common/small_vector.h"

 namespace oneflow {
 namespace one {

 class StatefulLocalOpKernel;
-class ConsistentTensorInferResult;
+class GlobalTensorInferResult;

-using EagerBlobObjectList = std::vector<std::shared_ptr<vm::EagerBlobObject>>;
-using EagerBlobObjectListPtr =
-    std::shared_ptr<const std::vector<std::shared_ptr<vm::EagerBlobObject>>>;

 }  // namespace one
@@ -60,10 +57,7 @@ class TmpTensor final : public user_op::Tensor {
   char* mut_tmp_buffer_ptr() { return tmp_buffer_ptr_; }

-  void init_tmp_buffer_ptr(char* ptr) {
-    CHECK_EQ(tmp_buffer_ptr_, nullptr);
-    tmp_buffer_ptr_ = ptr;
-  }
+  void set_tmp_buffer_ptr(char* ptr) { tmp_buffer_ptr_ = ptr; }

  private:
   std::shared_ptr<MemoryCase> mem_case_;
@@ -73,35 +67,34 @@ class TmpTensor final : public user_op::Tensor {
 class CallContext {
  public:
-  CallContext(
-      ComposedAttrMap&& composed_attrs, const one::EagerBlobObjectListPtr& inputs,
-      const one::EagerBlobObjectListPtr& outputs,
-      const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result,
-      const one::OpExprInterpContext& op_interp_ctx, const std::shared_ptr<MemoryCase>& mem_case)
+  CallContext(ComposedAttrMap&& composed_attrs, vm::EagerBlobObjectList&& inputs,
+              vm::EagerBlobObjectList&& outputs,
+              const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result,
+              const one::OpExprInterpContext& op_interp_ctx,
+              const std::shared_ptr<MemoryCase>& mem_case)
       : composed_attrs_(std::move(composed_attrs)),
-        inputs_(inputs),
-        outputs_(outputs),
-        consistent_tensor_infer_result_(consistent_tensor_infer_result),
+        inputs_(std::move(inputs)),
+        outputs_(std::move(outputs)),
+        global_tensor_infer_result_(global_tensor_infer_result),
         op_interp_ctx_(op_interp_ctx),
         tmp_tensor_(mem_case) {}

   ~CallContext() = default;

   const ComposedAttrMap& composed_attrs() const { return composed_attrs_; }
-  const one::EagerBlobObjectListPtr& inputs() const { return inputs_; }
-  const one::EagerBlobObjectListPtr& outputs() const { return outputs_; }
-  const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result()
-      const {
-    return consistent_tensor_infer_result_;
+  const vm::EagerBlobObjectList& inputs() const { return inputs_; }
+  const vm::EagerBlobObjectList& outputs() const { return outputs_; }
+  const std::shared_ptr<const one::GlobalTensorInferResult>& global_tensor_infer_result() const {
+    return global_tensor_infer_result_;
   }
   const one::OpExprInterpContext& op_interp_ctx() const { return op_interp_ctx_; }
   TmpTensor* mut_tmp_tensor() { return &tmp_tensor_; }

  private:
   const ComposedAttrMap composed_attrs_;
-  const one::EagerBlobObjectListPtr inputs_;
-  const one::EagerBlobObjectListPtr outputs_;
-  const std::shared_ptr<const one::ConsistentTensorInferResult> consistent_tensor_infer_result_;
+  const vm::EagerBlobObjectList inputs_;
+  const vm::EagerBlobObjectList outputs_;
+  const std::shared_ptr<const one::GlobalTensorInferResult> global_tensor_infer_result_;
   const one::OpExprInterpContext op_interp_ctx_;
   TmpTensor tmp_tensor_;
 };
...
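Beyond the diff itself, the ownership change above may be easier to see in isolation: CallContext now takes its operand lists by value (the small_vector-backed vm::EagerBlobObjectList) and moves them into const members, instead of holding a shared_ptr to a const std::vector. A minimal standalone sketch of that pattern, using plain std::vector as a stand-in and hypothetical names throughout (not OneFlow code):

// --- illustrative sketch only, not part of the commit ---
#include <memory>
#include <utility>
#include <vector>

struct BlobObject {};  // stand-in for vm::EagerBlobObject
using BlobObjectList = std::vector<std::shared_ptr<BlobObject>>;

class CallContextSketch {
 public:
  // Lists are passed by value and moved in: the caller relinquishes its copies,
  // so no shared_ptr<const std::vector<...>> indirection is needed.
  CallContextSketch(BlobObjectList&& inputs, BlobObjectList&& outputs)
      : inputs_(std::move(inputs)), outputs_(std::move(outputs)) {}

  const BlobObjectList& inputs() const { return inputs_; }
  const BlobObjectList& outputs() const { return outputs_; }

 private:
  const BlobObjectList inputs_;
  const BlobObjectList outputs_;
};

int main() {
  BlobObjectList inputs{std::make_shared<BlobObject>()};
  BlobObjectList outputs{std::make_shared<BlobObject>()};
  CallContextSketch ctx(std::move(inputs), std::move(outputs));
  return ctx.inputs().size() == 1 ? 0 : 1;
}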
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_
#include "oneflow/core/vm/critical_section_status_querier.h"
#include "oneflow/core/eager/critical_section_phy_instr_operand.h"
#include "oneflow/core/job/critical_section_instance.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/vm/thread_ctx.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/vm/ref_cnt_instruction_status_querier.h"
#include "oneflow/core/profiler/profiler.h"
namespace oneflow {
namespace vm {
class CriticalSectionBeginInstructionType final : public InstructionType {
public:
CriticalSectionBeginInstructionType(const CriticalSectionBeginInstructionType&) = delete;
CriticalSectionBeginInstructionType(CriticalSectionBeginInstructionType&&) = delete;
CriticalSectionBeginInstructionType& operator=(const CriticalSectionBeginInstructionType&) =
delete;
CriticalSectionBeginInstructionType& operator=(CriticalSectionBeginInstructionType&&) = delete;
CriticalSectionBeginInstructionType() = default;
~CriticalSectionBeginInstructionType() = default;
std::string DebugName(const vm::Instruction& instruction) const override {
return "CriticalSectionBegin";
}
Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }
void Compute(vm::Instruction* instruction) const override {
OF_PROFILER_RANGE_GUARD("CriticalSectionBegin");
{
auto ptr = instruction->phy_instr_operand();
auto phy_instr_operand = std::dynamic_pointer_cast<CriticalSectionBeginPhyInstrOperand>(ptr);
CHECK_NOTNULL(phy_instr_operand);
const auto& critical_section_instance = MakeCriticalSectionInstance(phy_instr_operand);
const auto& job_name = critical_section_instance->job_name();
auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();
for (int i = 0; i < phy_instr_operand->interfaces_op_names().size(); ++i) {
if (phy_instr_operand->interfaces_valid().at(i)) {
const std::string& interface_op_name = phy_instr_operand->interfaces_op_names().at(i);
const auto& buffer_name =
phy_instr_operand->GetInterfaceBufferName(job_name, interface_op_name);
buffer_mgr->Get(buffer_name)->Push(critical_section_instance);
}
}
const auto& callback_buffer_name =
phy_instr_operand->GetInterfaceCriticalSectionCallbackBufferName(job_name);
buffer_mgr->Get(callback_buffer_name)->Push(critical_section_instance);
const auto& wait_buffer_name =
phy_instr_operand->GetInterfaceCriticalSectionWaitBufferName(job_name);
buffer_mgr->Get(wait_buffer_name)->Push(critical_section_instance);
}
{
auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer();
auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data);
status_querier->SetLaunched(std::make_shared<NaiveEventRecord>());
}
}
private:
class NaiveCriticalSectionInstance final : public CriticalSectionInstance {
public:
NaiveCriticalSectionInstance(
const std::shared_ptr<CriticalSectionBeginPhyInstrOperand>& phy_instr_operand,
const std::string& job_name)
: CriticalSectionInstance(), phy_instr_operand_(phy_instr_operand), job_name_(job_name) {}
~NaiveCriticalSectionInstance() override = default;
const std::string& job_name() const override { return job_name_; }
void AccessBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const override {
phy_instr_operand_->AccessBlobByOpName(ofblob_ptr, op_name);
}
void Finish() const override { phy_instr_operand_->Finish(); }
private:
std::shared_ptr<CriticalSectionBeginPhyInstrOperand> phy_instr_operand_;
std::string job_name_;
};
std::shared_ptr<CriticalSectionInstance> MakeCriticalSectionInstance(
const std::shared_ptr<CriticalSectionBeginPhyInstrOperand>& phy_instr_operand) const {
phy_instr_operand->FinishInvalidInterfaceEventRecords();
const auto& job_name = phy_instr_operand->nn_graph()->job_name();
return std::make_shared<NaiveCriticalSectionInstance>(phy_instr_operand, job_name);
}
};
class CriticalSectionEndInstructionType final : public InstructionType {
public:
CriticalSectionEndInstructionType(const CriticalSectionEndInstructionType&) = delete;
CriticalSectionEndInstructionType(CriticalSectionEndInstructionType&&) = delete;
CriticalSectionEndInstructionType& operator=(const CriticalSectionEndInstructionType&) = delete;
CriticalSectionEndInstructionType& operator=(CriticalSectionEndInstructionType&&) = delete;
CriticalSectionEndInstructionType() = default;
~CriticalSectionEndInstructionType() = default;
std::string DebugName(const vm::Instruction& instruction) const override {
return "CriticalSectionEnd";
}
Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }
void Compute(vm::Instruction* instruction) const override {
const auto* ptr = instruction->phy_instr_operand().get();
const auto* phy_instr_operand = dynamic_cast<const CriticalSectionEndPhyInstrOperand*>(ptr);
CHECK_NOTNULL(phy_instr_operand);
auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer();
auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data);
status_querier->SetLaunched(phy_instr_operand->event_record());
}
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_CRITICAL_SECTION_INSTRUCTION_TYPE_H_
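CriticalSectionBeginInstructionType::Compute above fans the same CriticalSectionInstance out to several named buffers (one per valid interface op, plus the callback and wait buffers), where graph-side consumers later pull it. A rough standalone sketch of that named-channel handoff, with std::queue and condition_variable loosely standing in for OneFlow's BufferMgr; all names here are illustrative, not the real API:

// --- illustrative sketch only, not part of the commit ---
#include <condition_variable>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>

// Minimal blocking channel, loosely playing the role of Buffer<T>::Push/Pull.
template <typename T>
class Channel {
 public:
  void Push(T value) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push(std::move(value));
    }
    cv_.notify_one();
  }
  T Pull() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [&] { return !queue_.empty(); });
    T value = std::move(queue_.front());
    queue_.pop();
    return value;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> queue_;
};

int main() {
  // One channel per buffer name, like buffer_mgr->Get(buffer_name).
  std::map<std::string, Channel<std::shared_ptr<int>>> buffers;
  buffers["callback_buffer"];  // pre-register the named buffers
  buffers["wait_buffer"];

  auto instance = std::make_shared<int>(42);  // stand-in for a CriticalSectionInstance
  std::thread consumer([&] { buffers.at("wait_buffer").Pull(); });

  buffers.at("callback_buffer").Push(instance);  // the same instance is fanned out ...
  buffers.at("wait_buffer").Push(instance);      // ... to every interested buffer
  consumer.join();
  return 0;
}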
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/eager/critical_section_phy_instr_operand.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/device/ep_based_event_record.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/vm/stream.h"
namespace oneflow {
namespace vm {
void CriticalSectionBeginPhyInstrOperand::ForEachMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
for (const auto& eager_blob_object : *eager_blob_objects_) {
DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object()));
}
}
void CriticalSectionEndPhyInstrOperand::ForEachMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
DoEach(CHECK_JUST(eager_blob_object_->compute_local_dep_object()));
}
void CriticalSectionBeginPhyInstrOperand::ForEachMutMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
DoEach(vm_stream_->schedule_local_dep_object().get());
}
void CriticalSectionBeginPhyInstrOperand::FinishInvalidInterfaceEventRecords() {
for (const auto& op_name : interfaces_op_names()) {
size_t index = CHECK_JUST(MapAt(op_name2interface_index_, op_name));
if (!interfaces_valid().at(index)) {
const auto& iter = op_name2end_event_record_->find(op_name);
CHECK(iter != op_name2end_event_record_->end());
iter->second->Init(std::make_shared<NaiveEventRecord>());
}
}
}
void CriticalSectionBeginPhyInstrOperand::Finish() {
for (const auto& pair : *op_name2end_event_record_) {
pair.second->TryInit(std::make_shared<NaiveEventRecord>());
}
}
void InputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_blob_ptr,
const std::string& op_name) {
int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name));
CHECK(interfaces_valid().at(i));
OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
const auto& eager_blob_object = eager_blob_objects_->at(i);
{
size_t header_size = of_blob->mut_blob()->blob_desc().ByteSizeOfBlobHeader();
CHECK_EQ(header_size, eager_blob_object->shape().NumAxes() * sizeof(int64_t));
std::memcpy(of_blob->mut_blob()->mut_header_ptr(), eager_blob_object->mut_header_ptr(),
header_size);
}
const auto& end_event_record = op_name2end_event_record_->at(op_name);
if (eager_blob_object->dptr() == nullptr) {
end_event_record->Init(std::make_shared<NaiveEventRecord>());
} else {
{
const size_t body_bytes = of_blob->blob().ByteSizeOfBlobBody();
CHECK_EQ(eager_blob_object->ByteSizeOfBlobBody(), body_bytes);
AutoMemcpy(of_blob->stream(), of_blob->mut_blob()->mut_dptr(), eager_blob_object->dptr(),
body_bytes, of_blob->blob().mem_case(), eager_blob_object->mem_case());
}
end_event_record->Init(EpBasedEventRecord::MakeEventRecord(of_blob->stream()));
}
}
void OutputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_blob_ptr,
const std::string& op_name) {
int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name));
CHECK(interfaces_valid().at(i));
OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
auto& eager_blob_object = eager_blob_objects_->at(i);
of_blob->blob().shape_view().ToShape(eager_blob_object->mut_shape());
const auto& end_event_record = op_name2end_event_record_->at(op_name);
if (eager_blob_object->dptr() == nullptr) {
end_event_record->Init(std::make_shared<NaiveEventRecord>());
} else {
{
const size_t body_bytes = of_blob->blob().ByteSizeOfBlobBody();
CHECK_EQ(eager_blob_object->ByteSizeOfBlobBody(), body_bytes);
AutoMemcpy(of_blob->stream(), eager_blob_object->mut_dptr(), of_blob->blob().dptr(),
body_bytes, eager_blob_object->mem_case(), of_blob->blob().mem_case());
}
end_event_record->Init(EpBasedEventRecord::MakeEventRecord(of_blob->stream()));
}
}
void CriticalSectionEndPhyInstrOperand::ForEachMutMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
DoEach(vm_stream_->schedule_local_dep_object().get());
}
} // namespace vm
} // namespace oneflow
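In the two AccessBlobByOpName implementations above, the shape header (NumAxes * sizeof(int64_t) bytes) and the blob body are copied separately, with CHECK_EQ guarding that both sides agree on the byte counts before each copy. A compact standalone illustration of that header-then-body copy, using plain memcpy in place of AutoMemcpy and hypothetical buffers:

// --- illustrative sketch only, not part of the commit ---
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  // "Source" tensor: a shape header (dims) plus a flat float body.
  std::vector<int64_t> src_dims{2, 3};
  std::vector<float> src_body{0, 1, 2, 3, 4, 5};

  // "Destination" blob with buffers of matching sizes.
  std::vector<int64_t> dst_dims(src_dims.size());
  std::vector<float> dst_body(src_body.size());

  // Header first: NumAxes * sizeof(int64_t) bytes, sizes must agree (like CHECK_EQ).
  const std::size_t header_bytes = src_dims.size() * sizeof(int64_t);
  assert(header_bytes == dst_dims.size() * sizeof(int64_t));
  std::memcpy(dst_dims.data(), src_dims.data(), header_bytes);

  // Then the body: byte sizes must agree as well (like CHECK_EQ on ByteSizeOfBlobBody).
  const std::size_t body_bytes = src_body.size() * sizeof(float);
  assert(body_bytes == dst_body.size() * sizeof(float));
  std::memcpy(dst_body.data(), src_body.data(), body_bytes);  // AutoMemcpy in the real code

  return dst_body[5] == 5.0f ? 0 : 1;
}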
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_CRITICAL_SECTION_PHY_INSTR_OPERAND_H_
#define ONEFLOW_CORE_EAGER_CRITICAL_SECTION_PHY_INSTR_OPERAND_H_
#include "oneflow/core/vm/phy_instr_operand.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/device/event_record.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/buffer_manager.h"
namespace oneflow {
namespace one {
using EagerBlobObjectListPtr =
std::shared_ptr<const std::vector<std::shared_ptr<vm::EagerBlobObject>>>;
}
namespace vm {
class Stream;
class CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand {
public:
CriticalSectionBeginPhyInstrOperand(const CriticalSectionBeginPhyInstrOperand&) = delete;
CriticalSectionBeginPhyInstrOperand(CriticalSectionBeginPhyInstrOperand&&) = delete;
CriticalSectionBeginPhyInstrOperand& operator=(const CriticalSectionBeginPhyInstrOperand&) =
delete;
CriticalSectionBeginPhyInstrOperand& operator=(CriticalSectionBeginPhyInstrOperand&&) = delete;
virtual ~CriticalSectionBeginPhyInstrOperand() = default;
explicit CriticalSectionBeginPhyInstrOperand(
const std::shared_ptr<NNGraphIf>& nn_graph,
const one::EagerBlobObjectListPtr& eager_blob_objects,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&
op_name2end_event_record,
vm::Stream* vm_stream)
: nn_graph_(nn_graph),
eager_blob_objects_(eager_blob_objects),
op_name2end_event_record_(op_name2end_event_record),
vm_stream_(vm_stream) {}
const std::shared_ptr<NNGraphIf>& nn_graph() const { return nn_graph_; }
const one::EagerBlobObjectListPtr& eager_blob_objects() const { return eager_blob_objects_; }
void ForEachMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
void ForEachMutMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
virtual const std::vector<std::string>& interfaces_op_names() const = 0;
virtual const std::vector<bool>& interfaces_valid() const = 0;
virtual std::string GetInterfaceBufferName(const std::string& job_name,
const std::string& op_name) const = 0;
virtual std::string GetInterfaceCriticalSectionCallbackBufferName(
const std::string& job_name) const = 0;
virtual std::string GetInterfaceCriticalSectionWaitBufferName(
const std::string& job_name) const = 0;
virtual void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) = 0;
void FinishInvalidInterfaceEventRecords();
void Finish();
void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
for (const auto& eager_blob_object : *eager_blob_objects_) { DoEach(eager_blob_object.get()); }
}
protected:
std::shared_ptr<NNGraphIf> nn_graph_;
one::EagerBlobObjectListPtr eager_blob_objects_;
std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>
op_name2end_event_record_;
HashMap<std::string, size_t> op_name2interface_index_;
vm::Stream* vm_stream_;
};
class InputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeginPhyInstrOperand {
public:
InputCriticalSectionBeginPhyInstrOperand(
const std::shared_ptr<NNGraphIf>& nn_graph,
const one::EagerBlobObjectListPtr& eager_blob_objects,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&
op_name2end_event_record,
vm::Stream* vm_stream)
: CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record,
vm_stream),
input_dependences_(),
output_dependences_() {
ForEachConstMirroredObject(SetInserter(&input_dependences_));
ForEachMutMirroredObject(SetInserter(&output_dependences_));
ForEachMut2MirroredObject(SetInserter(&output_dependences_));
CHECK_EQ(nn_graph->inputs_op_names().size(), eager_blob_objects->size());
CHECK_EQ(nn_graph->inputs_op_names().size(), nn_graph->inputs_valid().size());
for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) {
CHECK(op_name2interface_index_.emplace(nn_graph->inputs_op_names().at(i), i).second);
}
}
~InputCriticalSectionBeginPhyInstrOperand() override = default;
const DependenceVector& input_dependences() const override { return input_dependences_; }
const DependenceVector& output_dependences() const override { return output_dependences_; }
// for inputs
void ForEachConstMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
ForEachMirroredObject(DoEach);
}
// for outputs
const std::vector<std::string>& interfaces_op_names() const override {
return nn_graph_->inputs_op_names();
}
const std::vector<bool>& interfaces_valid() const override { return nn_graph_->inputs_valid(); }
std::string GetInterfaceBufferName(const std::string& job_name,
const std::string& op_name) const override {
return GetInputBufferName(job_name, op_name);
}
std::string GetInterfaceCriticalSectionCallbackBufferName(
const std::string& job_name) const override {
return GetInputCriticalSectionCallbackBufferName(job_name);
}
std::string GetInterfaceCriticalSectionWaitBufferName(
const std::string& job_name) const override {
return GetInputCriticalSectionWaitBufferName(job_name);
}
void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) override;
void ForEachMut2MirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}
private:
DependenceVector input_dependences_;
DependenceVector output_dependences_;
};
class OutputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeginPhyInstrOperand {
public:
OutputCriticalSectionBeginPhyInstrOperand(
const std::shared_ptr<NNGraphIf>& nn_graph,
const one::EagerBlobObjectListPtr& eager_blob_objects,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<SharedEventRecord>>>&
op_name2end_event_record,
vm::Stream* vm_stream)
: CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record,
vm_stream),
input_dependences_(),
output_dependences_() {
ForEachConstMirroredObject(SetInserter(&input_dependences_));
ForEachMutMirroredObject(SetInserter(&output_dependences_));
ForEachMut2MirroredObject(SetInserter(&output_dependences_));
CHECK_EQ(nn_graph->outputs_op_names().size(), eager_blob_objects->size());
CHECK_EQ(nn_graph->outputs_op_names().size(), nn_graph->outputs_valid().size());
for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) {
CHECK(op_name2interface_index_.emplace(nn_graph->outputs_op_names().at(i), i).second);
}
}
~OutputCriticalSectionBeginPhyInstrOperand() override = default;
const DependenceVector& input_dependences() const override { return input_dependences_; }
const DependenceVector& output_dependences() const override { return output_dependences_; }
// for inputs
void ForEachConstMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}
// for outputs
void ForEachMut2MirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
ForEachMirroredObject(DoEach);
}
const std::vector<std::string>& interfaces_op_names() const override {
return nn_graph_->outputs_op_names();
}
const std::vector<bool>& interfaces_valid() const override { return nn_graph_->outputs_valid(); }
std::string GetInterfaceBufferName(const std::string& job_name,
const std::string& op_name) const override {
return GetOutputBufferName(job_name, op_name);
}
std::string GetInterfaceCriticalSectionCallbackBufferName(
const std::string& job_name) const override {
return GetOutputCriticalSectionCallbackBufferName(job_name);
}
std::string GetInterfaceCriticalSectionWaitBufferName(
const std::string& job_name) const override {
return GetOutputCriticalSectionWaitBufferName(job_name);
}
void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) override;
private:
DependenceVector input_dependences_;
DependenceVector output_dependences_;
};
class CriticalSectionEndPhyInstrOperand : public PhyInstrOperand {
public:
CriticalSectionEndPhyInstrOperand(const std::shared_ptr<EagerBlobObject>& eager_blob_object,
const std::shared_ptr<SharedEventRecord>& event_record,
vm::Stream* vm_stream)
: eager_blob_object_(eager_blob_object), event_record_(event_record), vm_stream_(vm_stream) {}
virtual ~CriticalSectionEndPhyInstrOperand() = default;
const std::shared_ptr<SharedEventRecord>& event_record() const { return event_record_; }
void ForEachMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
void ForEachMutMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
DoEach(eager_blob_object_.get());
}
private:
std::shared_ptr<EagerBlobObject> eager_blob_object_;
std::shared_ptr<SharedEventRecord> event_record_;
vm::Stream* vm_stream_;
};
class InputCriticalSecondEndPhyInstrOperand final : public CriticalSectionEndPhyInstrOperand {
public:
InputCriticalSecondEndPhyInstrOperand(const std::shared_ptr<EagerBlobObject>& eager_blob_object,
const std::shared_ptr<SharedEventRecord>& event_record,
vm::Stream* vm_stream)
: CriticalSectionEndPhyInstrOperand(eager_blob_object, event_record, vm_stream),
input_dependences_(),
output_dependences_() {
ForEachConstMirroredObject(SetInserter(&input_dependences_));
ForEachMutMirroredObject(SetInserter(&output_dependences_));
ForEachMut2MirroredObject(SetInserter(&output_dependences_));
}
~InputCriticalSecondEndPhyInstrOperand() override = default;
const DependenceVector& input_dependences() const override { return input_dependences_; }
const DependenceVector& output_dependences() const override { return output_dependences_; }
void ForEachConstMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
ForEachMirroredObject(DoEach);
}
void ForEachMut2MirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}
private:
DependenceVector input_dependences_;
DependenceVector output_dependences_;
};
class OutputCriticalSecondEndPhyInstrOperand final : public CriticalSectionEndPhyInstrOperand {
public:
OutputCriticalSecondEndPhyInstrOperand(const std::shared_ptr<EagerBlobObject>& eager_blob_object,
const std::shared_ptr<SharedEventRecord>& event_record,
vm::Stream* vm_stream)
: CriticalSectionEndPhyInstrOperand(eager_blob_object, event_record, vm_stream),
input_dependences_(),
output_dependences_() {
ForEachConstMirroredObject(SetInserter(&input_dependences_));
ForEachMutMirroredObject(SetInserter(&output_dependences_));
ForEachMut2MirroredObject(SetInserter(&output_dependences_));
}
~OutputCriticalSecondEndPhyInstrOperand() override = default;
const DependenceVector& input_dependences() const override { return input_dependences_; }
const DependenceVector& output_dependences() const override { return output_dependences_; }
// for inputs
void ForEachConstMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}
// for outputs
void ForEachMut2MirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
ForEachMirroredObject(DoEach);
}
private:
DependenceVector input_dependences_;
DependenceVector output_dependences_;
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_CRITICAL_SECTION_PHY_INSTR_OPERAND_H_
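Each operand class above fills input_dependences_ and output_dependences_ in its constructor by running its ForEachConst/Mut/Mut2 visitors through SetInserter. A bare-bones standalone sketch of that idea, with plain pointers and std::vector standing in for MirroredObject and DependenceVector (illustrative names only):

// --- illustrative sketch only, not part of the commit ---
#include <functional>
#include <iostream>
#include <vector>

struct Dep {};  // stand-in for vm::MirroredObject / LocalDepObject
using DepVector = std::vector<Dep*>;

// Roughly what PhyInstrOperand::SetInserter provides: a callback appending to a vector.
std::function<void(Dep*)> SetInserter(DepVector* deps) {
  return [deps](Dep* dep) { deps->push_back(dep); };
}

struct OperandSketch {
  Dep read_only_blob;
  Dep written_blob;

  // const dependences: objects the instruction only reads.
  void ForEachConstDep(const std::function<void(Dep*)>& DoEach) { DoEach(&read_only_blob); }
  // mut dependences: objects the instruction writes (ordering edges hang off these).
  void ForEachMutDep(const std::function<void(Dep*)>& DoEach) { DoEach(&written_blob); }

  DepVector input_dependences;
  DepVector output_dependences;

  OperandSketch() {
    ForEachConstDep(SetInserter(&input_dependences));
    ForEachMutDep(SetInserter(&output_dependences));
  }
};

int main() {
  OperandSketch operand;
  std::cout << operand.input_dependences.size() << " " << operand.output_dependences.size()
            << "\n";  // prints "1 1"
  return 0;
}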
@@ -18,53 +18,77 @@ limitations under the License.
 #include "oneflow/core/framework/to_string.h"
 #include "oneflow/core/framework/shut_down_util.h"
 #include "oneflow/core/common/shape_vec.h"
+#include "oneflow/core/common/tensor_meta.h"

 namespace oneflow {
 namespace vm {

-EagerBlobObject::EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,
-                                 const std::shared_ptr<Shape>& shape,
-                                 const std::shared_ptr<Stride>& stride, DataType data_type,
-                                 const std::shared_ptr<TensorStorage>& tensor_storage,
-                                 const intrusive::shared_ptr<LocalDepObject>& dep_object)
+EagerBlobObject::EagerBlobObject(
+    const std::shared_ptr<MemoryCase>& mem_case,
+    const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,
+    const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,
+    DataType data_type, const std::shared_ptr<TensorStorage>& tensor_storage,
+    const intrusive::shared_ptr<LocalDepObject>& dep_object)
     : is_dynamic_(false),
       mem_case_(mem_case),
       data_type_(data_type),
-      shape_(shape),
-      stride_(stride),
       storage_offset_(0),
       tensor_storage_(tensor_storage),
-      mem_ptr_for_allocation_compuation_pipelining_(nullptr),
-      inited_mem_ptr_for_allocation_compuation_pipelining_(false),
-      is_non_pod_object_placement_newed_(false),
-      is_shape_synced_(true),
       compute_local_dep_object_(dep_object),
-      blob_desc_(shape, stride, data_type) {
-  CHECK(static_cast<bool>(shape));
-  CHECK(static_cast<bool>(stride));
+      static_local_tensor_meta_(static_local_tensor_meta),
+      dynamic_local_tensor_meta_(dynamic_local_tensor_meta) {
   CHECK(static_cast<bool>(tensor_storage));
 }

-Blob* EagerBlobObject::blob() {
-  if (!blob_) {
-    blob_.reset(new Blob(*mem_case_, &blob_desc_, mut_header_ptr(), mut_dptr<char>()));
-  }
-  return blob_.get();
-}
-
-void EagerBlobObject::set_storage_offset(const int64_t offset) { storage_offset_ = offset; }
-
-void EagerBlobObject::TryInitNonPODTypeEagerBlobObjectIfNeed() {
-  if (!IsPODDataType(data_type())) {
-    if (!is_non_pod_object_placement_newed_) {
-      InitNonPODTypeEagerBlobObjectIfNeed(tensor_storage_->non_pod_allocator(), this);
-      is_non_pod_object_placement_newed_ = true;
-    }
-  }
-}
-
-Maybe<void> EagerBlobObject::TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) {
-  vm::Allocator* allocator = device_ctx->mut_allocator();
+// user_op::TensorDesc overrides
+const Shape& EagerBlobObject::shape() const {
+  if (dynamic_local_tensor_meta_) {
+    return dynamic_local_tensor_meta_->shape();
+  } else {
+    return static_local_tensor_meta_->shape();
+  }
+}
+
+const Stride& EagerBlobObject::stride() const {
+  if (dynamic_local_tensor_meta_) {
+    return dynamic_local_tensor_meta_->stride();
+  } else {
+    return static_local_tensor_meta_->stride();
+  }
+}
+
+void EagerBlobObject::set_shape(const Shape& shape) {
+  CHECK(dynamic_local_tensor_meta_);
+  std::const_pointer_cast<one::MutLocalTensorMeta>(dynamic_local_tensor_meta_)->set_shape(shape);
+}
+
+void EagerBlobObject::set_stride(const Stride& stride) {
+  CHECK(dynamic_local_tensor_meta_);
+  std::const_pointer_cast<one::MutLocalTensorMeta>(dynamic_local_tensor_meta_)->set_stride(stride);
+}
+
+MutShapeView EagerBlobObject::mut_shape_view() {
+  CHECK(dynamic_local_tensor_meta_);
+  return *const_cast<Shape*>(dynamic_local_tensor_meta_->shape_ptr().get());
+}
+
+std::shared_ptr<const Shape> EagerBlobObject::shape_ptr() const {
+  if (dynamic_local_tensor_meta_) {
+    return dynamic_local_tensor_meta_->shape_ptr();
+  } else {
+    return static_local_tensor_meta_->shape_ptr();
+  }
+}
+
+std::shared_ptr<const Stride> EagerBlobObject::stride_ptr() const {
+  if (dynamic_local_tensor_meta_) {
+    return dynamic_local_tensor_meta_->stride_ptr();
+  } else {
+    return static_local_tensor_meta_->stride_ptr();
+  }
+}
+
+void EagerBlobObject::set_storage_offset(const int64_t offset) { storage_offset_ = offset; }
+
+Maybe<bool> EagerBlobObject::TryAllocateBlobBodyMemory(vm::Allocator* allocator) {
   size_t required_body_bytes = AlignedByteSizeOfBlobBody();
   if (required_body_bytes == 0) {
     CHECK_ISNULL_OR_RETURN(tensor_storage_->blob_dptr());
@@ -81,10 +105,10 @@ Maybe<void> EagerBlobObject::TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) {
   };
   tensor_storage_->set_blob_dptr(std::unique_ptr<char, std::function<void(char*)>>(dptr, Free),
                                  required_body_bytes);
-    InitMemPtrForAllocationComputationPipelining();
+    InitNonPODTypeEagerBlobObjectIfNeed(tensor_storage_->non_pod_allocator(), this);
+    return true;
   }
-  InitOrCheckMemPtrForAllocationComputationPipelining();
-  return Maybe<void>::Ok();
+  return false;
 }

 }  // namespace vm
...
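The rewritten accessors above consistently prefer the dynamic (mutable) tensor meta and fall back to the static meta when no dynamic meta is attached; mutation is only permitted when a dynamic meta exists. A small standalone sketch of that dispatch with trivial stand-in types (hypothetical names, not the OneFlow classes):

// --- illustrative sketch only, not part of the commit ---
#include <cassert>
#include <cstdint>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

struct StaticMeta { Shape shape; };   // fixed at op build time
struct DynamicMeta { Shape shape; };  // mutable, attached only for dynamic-shape tensors

class BlobObjectSketch {
 public:
  BlobObjectSketch(StaticMeta static_meta, std::shared_ptr<DynamicMeta> dynamic_meta)
      : static_meta_(std::move(static_meta)), dynamic_meta_(std::move(dynamic_meta)) {}

  // Prefer the dynamic meta when present, otherwise fall back to the static meta.
  const Shape& shape() const {
    return dynamic_meta_ ? dynamic_meta_->shape : static_meta_.shape;
  }

  // Mutation is only legal when a dynamic meta exists (mirrors the CHECK in the diff).
  void set_shape(const Shape& shape) {
    assert(dynamic_meta_ != nullptr);
    dynamic_meta_->shape = shape;
  }

 private:
  StaticMeta static_meta_;
  std::shared_ptr<DynamicMeta> dynamic_meta_;  // null for static-shape tensors
};

int main() {
  BlobObjectSketch static_only({{2, 3}}, nullptr);
  BlobObjectSketch dynamic_shaped({{2, 3}}, std::make_shared<DynamicMeta>(DynamicMeta{{4, 5}}));
  std::cout << static_only.shape()[0] << " " << dynamic_shaped.shape()[0] << "\n";  // "2 4"
  return 0;
}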
@@ -18,6 +18,7 @@ limitations under the License.
 #include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/optional.h"
+#include "oneflow/core/common/op_args_reserved_size.h"
 #include "oneflow/core/eager/local_dep_object.h"
 #include "oneflow/core/device/device_context.h"
 #include "oneflow/core/memory/memory_allocator.h"
@@ -25,24 +26,34 @@ limitations under the License.
 #include "oneflow/core/framework/stream.h"
 #include "oneflow/core/framework/tensor_methods.h"
 #include "oneflow/core/framework/user_op_tensor.h"
-#include "oneflow/core/framework/tensor_desc.h"
+#include "oneflow/core/common/tensor_desc.h"
 #include "oneflow/core/register/blob.h"

 namespace oneflow {

+namespace one {
+
+class LocalTensorMeta;
+class MutLocalTensorMeta;
+
+}  // namespace one
+
 namespace vm {

 class TensorStorage {
  public:
   TensorStorage()
-      : non_pod_allocator_(std::make_unique<MemoryAllocator>()),
+      : blob_bytes_(0),
+        non_pod_allocator_(std::make_unique<MemoryAllocator>()),
         producer_stream_(NullOpt),
         last_used_stream_(NullOpt) {}

-  ~TensorStorage() {
+  virtual ~TensorStorage() {
     for (const auto& hook : storage_delete_hooks_) { hook(); }
   }

+  virtual bool is_allocated_in_vm() const = 0;
+
   size_t blob_bytes() const { return blob_bytes_; }

   char* blob_dptr() { return blob_dptr_.get(); }
@@ -84,58 +95,77 @@ class TensorStorage {
   std::vector<std::function<void()>> storage_delete_hooks_;
 };

+class InsideVmTensorStorage : public TensorStorage {
+ public:
+  InsideVmTensorStorage() = default;
+  ~InsideVmTensorStorage() = default;
+
+  bool is_allocated_in_vm() const override { return true; }
+};
+
+class OutsideVmTensorStorage : public TensorStorage {
+ public:
+  OutsideVmTensorStorage() = default;
+  ~OutsideVmTensorStorage() = default;
+
+  bool is_allocated_in_vm() const override { return false; }
+};
+
 class EagerBlobObject final : public user_op::Tensor,
                               public user_op::TensorDesc,
                               public std::enable_shared_from_this<EagerBlobObject> {
  public:
   EagerBlobObject(const EagerBlobObject&) = delete;
   EagerBlobObject(EagerBlobObject&&) = delete;
-  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case, const std::shared_ptr<Shape>& shape,
-                  const std::shared_ptr<Stride>& stride, DataType data_type,
-                  const std::shared_ptr<TensorStorage>& tensor_storage)
-      : EagerBlobObject(mem_case, shape, stride, data_type, tensor_storage,
-                        intrusive::shared_ptr<LocalDepObject>()) {}
-  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case, const std::shared_ptr<Shape>& shape,
-                  const std::shared_ptr<Stride>& stride, DataType data_type,
-                  const std::shared_ptr<TensorStorage>& tensor_storage,
+  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,
+                  const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,
+                  const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,
+                  DataType data_type, const std::shared_ptr<TensorStorage>& tensor_storage)
+      : EagerBlobObject(mem_case, static_local_tensor_meta, dynamic_local_tensor_meta, data_type,
+                        tensor_storage, intrusive::shared_ptr<LocalDepObject>()) {}
+  EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,
+                  const Symbol<one::LocalTensorMeta>& static_local_tensor_meta,
+                  const std::shared_ptr<const one::MutLocalTensorMeta>& dynamic_local_tensor_meta,
+                  DataType data_type, const std::shared_ptr<TensorStorage>& tensor_storage,
                   const intrusive::shared_ptr<LocalDepObject>& dep_object);

   ~EagerBlobObject() { tensor_storage_.reset(); }

+  const std::shared_ptr<const one::MutLocalTensorMeta>& mut_tensor_meta() {
+    return dynamic_local_tensor_meta_;
+  }
+  // Getters
+  const Symbol<one::LocalTensorMeta>& tensor_meta() const { return static_local_tensor_meta_; }
+
   // user_op::TensorDesc overrides
-  const Shape& shape() const override { return *shape_; }
-  Shape* mut_shape() override { return shape_.get(); }
-  const Stride& stride() const override { return *stride_; }
-  Stride* mut_stride() override { return stride_.get(); }
+  const Shape& shape() const override;
+  const Stride& stride() const override;
   DataType data_type() const override { return data_type_; }
-  DataType* mut_data_type() override { return &data_type_; }
   bool is_dynamic() const override { return is_dynamic_; }
-  bool* mut_is_dynamic() override { return &is_dynamic_; }
+  void set_shape(const Shape& shape) override;
+  void set_stride(const Stride& stride) override;
+  void set_data_type(DataType data_type) override { data_type_ = data_type; }
   void set_is_dynamic(bool is_dynamic) override { is_dynamic_ = is_dynamic; }

   // user_op::Tensor overrides
-  ShapeView shape_view() const override { return *shape_; }
-  MutShapeView mut_shape_view() override { return *shape_; }
+  ShapeView shape_view() const override { return shape(); }
+  MutShapeView mut_shape_view() override;
   const MemoryCase& mem_case() const override { return *mem_case_; }
   const void* raw_dptr() const override {
-    CHECK(inited_mem_ptr_for_allocation_compuation_pipelining_)
-        << "mem_ptr_for_allocation_compuation_pipelining_ not initialized. Please check if there "
-           "are any EagerBlobObjects created outside vm";
-    return mem_ptr_for_allocation_compuation_pipelining_
-           + storage_offset_ * GetSizeOfDataType(data_type_);
+    char* ptr = tensor_storage_->blob_dptr();
+    if (tensor_storage_->blob_bytes() > 0) { CHECK_NOTNULL(ptr); }
+    return ptr + storage_offset_ * GetSizeOfDataType(data_type_);
   }
   void* mut_raw_dptr() override { return const_cast<void*>(raw_dptr()); }

   void set_storage_offset(const int64_t offset);

-  [[deprecated("\"Blob\" will be removed in eager. Please avoid to use this method whenever "
-               "possible. Almost all methods of `Blob` are also in `EagerBlobObject`.")]] Blob*
-  blob();
-
-  Maybe<void> TryAllocateBlobBodyMemory(DeviceCtx* device_ctx);
+  // Returns true if allocate successfully.
+  Maybe<bool> TryAllocateBlobBodyMemory(vm::Allocator* allocator);
   Maybe<void> DeallocateBlobDataPtr() {
     tensor_storage_->Release();
-    tensor_storage_.reset(new TensorStorage);
+    tensor_storage_.reset(new InsideVmTensorStorage());
     return Maybe<void>::Ok();
   }

   void RegisterStorageDeleteHook(const std::function<void()>& hook) {
@@ -149,10 +179,6 @@ class EagerBlobObject final : public user_op::Tensor,
   std::shared_ptr<TensorStorage>& tensor_storage() { return tensor_storage_; }

-  bool is_shape_synced() const { return is_shape_synced_; }
-
-  void set_is_shape_synced(bool val) { is_shape_synced_ = val; }
-
   const Optional<Symbol<::oneflow::Stream>>& producer_stream() const {
     return tensor_storage_->producer_stream();
   }
@@ -167,10 +193,10 @@ class EagerBlobObject final : public user_op::Tensor,
     tensor_storage_->set_last_used_stream(last_used_stream);
   }

-  std::shared_ptr<const Shape> shape_ptr() const { return shape_; }
-  std::shared_ptr<const Stride> stride_ptr() const { return stride_; }
+  std::shared_ptr<const Shape> shape_ptr() const;
+  std::shared_ptr<const Stride> stride_ptr() const;

-  size_t ByteSizeOfBlobBody() const { return shape_->elem_cnt() * GetSizeOfDataType(data_type_); }
+  size_t ByteSizeOfBlobBody() const { return shape().elem_cnt() * GetSizeOfDataType(data_type_); }
   size_t AlignedByteSizeOfBlobBody() const {
     return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize);
   }
@@ -179,52 +205,28 @@ class EagerBlobObject final : public user_op::Tensor,
     return RoundUp(ByteSizeOfBlobHeader(), kBlobHeaderAlignSize);
   }

-  const char* header_ptr() const { return reinterpret_cast<const char*>(shape_->dim_vec().data()); }
-  char* mut_header_ptr() { return reinterpret_cast<char*>(shape_->dim_vec().data()); }
-
-  void InitOrCheckMemPtrForAllocationComputationPipelining() {
-    auto* ptr = tensor_storage_->blob_dptr();
-    if (inited_mem_ptr_for_allocation_compuation_pipelining_) {
-      CHECK_EQ(mem_ptr_for_allocation_compuation_pipelining_, ptr);
-    } else {
-      mem_ptr_for_allocation_compuation_pipelining_ = ptr;
-      inited_mem_ptr_for_allocation_compuation_pipelining_ = true;
-    }
+  const char* header_ptr() const { return reinterpret_cast<const char*>(shape().dim_vec().data()); }
+  char* mut_header_ptr() {
+    return reinterpret_cast<char*>(const_cast<int64_t*>(shape().dim_vec().data()));
   }

-  void TryInitNonPODTypeEagerBlobObjectIfNeed();
-
  private:
-  void InitMemPtrForAllocationComputationPipelining() {
-    auto* ptr = tensor_storage_->blob_dptr();
-    CHECK(!inited_mem_ptr_for_allocation_compuation_pipelining_)
-        << "mem_ptr_for_allocation_compuation_pipelining_ has been initialized.";
-    mem_ptr_for_allocation_compuation_pipelining_ = ptr;
-    inited_mem_ptr_for_allocation_compuation_pipelining_ = true;
-  }
-
   bool is_dynamic_;
   std::shared_ptr<MemoryCase> mem_case_;
   DataType data_type_;
-  std::shared_ptr<Shape> shape_;
-  std::shared_ptr<Stride> stride_;
   int64_t storage_offset_;
   std::shared_ptr<TensorStorage> tensor_storage_;
-  // For allocation-computation pipeline, the value of mem_ptr_for_allocation_compuation_pipelining_
-  // are kept even after tensor_storage_.reset().
-  char* mem_ptr_for_allocation_compuation_pipelining_;
-  bool inited_mem_ptr_for_allocation_compuation_pipelining_;
-  bool is_non_pod_object_placement_newed_;
-  std::atomic<bool> is_shape_synced_;
-  bool pin_memory_;
   intrusive::shared_ptr<LocalDepObject> compute_local_dep_object_;
-  // NOTE: Will be removed soon. Avoid to use it whenever possible.
-  BlobDesc blob_desc_;
-  std::unique_ptr<Blob> blob_;
+  Symbol<one::LocalTensorMeta> static_local_tensor_meta_;
+  std::shared_ptr<const one::MutLocalTensorMeta> dynamic_local_tensor_meta_;
 };

+using EagerBlobObjectList = small_vector<std::shared_ptr<vm::EagerBlobObject>, kOpArgsReservedSize>;
+using EagerBlobObjectListPtr = std::shared_ptr<const EagerBlobObjectList>;
+
 }  // namespace vm
 }  // namespace oneflow

 #endif  // ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_
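TensorStorage becomes an abstract base in the hunk above, with InsideVmTensorStorage and OutsideVmTensorStorage differing only in is_allocated_in_vm(). A tiny standalone sketch of that split; the real classes also carry the blob pointer, the non-POD allocator, and stream bookkeeping (stand-in names only):

// --- illustrative sketch only, not part of the commit ---
#include <iostream>
#include <memory>

class TensorStorageSketch {
 public:
  virtual ~TensorStorageSketch() = default;
  // The VM needs to know whether it owns (and may recycle) this storage.
  virtual bool is_allocated_in_vm() const = 0;
};

class InsideVmStorageSketch final : public TensorStorageSketch {
 public:
  bool is_allocated_in_vm() const override { return true; }
};

class OutsideVmStorageSketch final : public TensorStorageSketch {
 public:
  bool is_allocated_in_vm() const override { return false; }
};

int main() {
  std::unique_ptr<TensorStorageSketch> storage = std::make_unique<InsideVmStorageSketch>();
  std::cout << std::boolalpha << storage->is_allocated_in_vm() << "\n";  // true
  storage = std::make_unique<OutsideVmStorageSketch>();
  std::cout << std::boolalpha << storage->is_allocated_in_vm() << "\n";  // false
  return 0;
}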
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_
#include "oneflow/core/vm/lazy_job_device_context.h"
#include "oneflow/core/eager/lazy_job_phy_instr_operand.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/of_unused.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/vm/thread_ctx.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/vm/naive_instruction_status_querier.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/kernel/kernel_util.h"
namespace oneflow {
class LazyJobInstance final : public JobInstance {
public:
LazyJobInstance(const LazyJobInstance&) = delete;
LazyJobInstance(LazyJobInstance&&) = delete;
~LazyJobInstance() override = default;
LazyJobInstance(const std::string& job_name, const std::function<void()>& finish_cb)
: job_name_(job_name), finish_cb_(finish_cb) {}
std::string job_name() const override { return job_name_; }
void Finish() const override { finish_cb_(); }
std::string sole_input_op_name_in_user_job() const override {
UNIMPLEMENTED();
return std::string();
}
std::string sole_output_op_name_in_user_job() const override {
UNIMPLEMENTED();
return std::string();
}
void PushBlob(uint64_t ofblob_ptr) const override { UNIMPLEMENTED(); }
void PullBlob(uint64_t ofblob_ptr) const override { UNIMPLEMENTED(); }
private:
const std::string job_name_;
const std::function<void()> finish_cb_;
};
namespace vm {
class LaunchLazyJobInstructionType final : public InstructionType { // NOLINT
public:
LaunchLazyJobInstructionType(const LaunchLazyJobInstructionType&) = delete;
LaunchLazyJobInstructionType(LaunchLazyJobInstructionType&&) = delete;
LaunchLazyJobInstructionType() = default;
~LaunchLazyJobInstructionType() = default;
std::string DebugName(const vm::Instruction&) const override { return "LaunchLazyJob"; }
Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }
void Compute(vm::Instruction* instruction) const override {
const auto& cur_nn_graph = GetCurNNGraph(instruction);
auto* device_ctx = GetLazyJobDeviceCtx(instruction);
static thread_local int64_t run_id = 0;
{
OF_PROFILER_RANGE_GUARD("WaitUntilQueueEmptyIfFrontNNGraphNotEquals");
device_ctx->WaitUntilQueueEmptyIfFrontNNGraphNotEquals(cur_nn_graph);
}
{
OF_PROFILER_RANGE_GUARD("Send all buffers to BufferMgr");
const auto& job_instance = MakeJobInstance(instruction);
const auto& job_name = job_instance->job_name();
auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Push(job_instance);
buffer_mgr->Get(GetSourceTickBufferName(job_name))->Push(job_instance);
}
OF_UNUSED(run_id); // disable compiler warning.
OF_PROFILER_RANGE_GUARD("EnqueueNNGraph");
device_ctx->EnqueueNNGraph(cur_nn_graph);
}
private:
LazyJobDeviceCtx* GetLazyJobDeviceCtx(Instruction* instruction) const {
auto* stream = instruction->mut_stream();
auto* device_ctx = dynamic_cast<LazyJobDeviceCtx*>(stream->device_ctx().get());
CHECK_NOTNULL(device_ctx);
return device_ctx;
}
std::shared_ptr<NNGraphIf> GetCurNNGraph(Instruction* instruction) const {
const auto* ptr = instruction->phy_instr_operand().get();
const auto* phy_instr_operand = dynamic_cast<const LaunchLazyJobPhyInstrOperand*>(ptr);
CHECK_NOTNULL(phy_instr_operand);
return phy_instr_operand->nn_graph();
}
std::shared_ptr<LazyJobInstance> MakeJobInstance(Instruction* instruction) const {
const auto* ptr = instruction->phy_instr_operand().get();
const auto* phy_instr_operand = dynamic_cast<const LaunchLazyJobPhyInstrOperand*>(ptr);
CHECK_NOTNULL(phy_instr_operand);
const auto& nn_graph = phy_instr_operand->nn_graph();
const auto& FinishCb = [this, instruction]() {
auto* device_ctx = GetLazyJobDeviceCtx(instruction);
device_ctx->DequeueNNGraph();
auto* status_buffer = instruction->mut_status_buffer();
NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer())->set_done();
};
return std::make_shared<LazyJobInstance>(nn_graph->job_name(), FinishCb);
}
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_LAZY_JOB_INSTRUCTION_TYPE_H_
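LaunchLazyJobInstructionType::Compute above packages instruction completion into a finish callback carried by the LazyJobInstance; the consumer of the job instance calls Finish() once the lazy job completes, which dequeues the graph and marks the instruction's status buffer as done. A standalone sketch of that callback handoff using only std::function (hypothetical names):

// --- illustrative sketch only, not part of the commit ---
#include <functional>
#include <iostream>
#include <string>
#include <utility>

// Minimal stand-in for LazyJobInstance: a job name plus a completion callback.
class JobInstanceSketch {
 public:
  JobInstanceSketch(std::string job_name, std::function<void()> finish_cb)
      : job_name_(std::move(job_name)), finish_cb_(std::move(finish_cb)) {}

  const std::string& job_name() const { return job_name_; }
  void Finish() const { finish_cb_(); }  // invoked by the consumer once the job is done

 private:
  std::string job_name_;
  std::function<void()> finish_cb_;
};

int main() {
  bool instruction_done = false;
  // The producer captures what "done" means (setting the VM status buffer in the real code).
  JobInstanceSketch instance("train_job", [&] { instruction_done = true; });

  // ... here the instance would be pushed to the job's buffers and the runtime would run it ...

  instance.Finish();  // the runtime reports completion through the callback
  std::cout << std::boolalpha << instruction_done << "\n";  // prints "true"
  return 0;
}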
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/eager/lazy_job_phy_instr_operand.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/vm/virtual_machine.h"
namespace oneflow {
namespace vm {
void LaunchLazyJobPhyInstrOperand::ForEachMutMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
for (const auto& eager_blob_object : *param_blob_objects_) {
DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object()));
}
DoEach(CHECK_JUST(SingletonMaybe<VirtualMachine>())
->FindOrCreateTransportLocalDepObject()
.Mutable());
}
} // namespace vm
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_LAZY_JOB_PHY_INSTR_OPERAND_H_
#define ONEFLOW_CORE_EAGER_LAZY_JOB_PHY_INSTR_OPERAND_H_
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/eager/local_dep_object.h"
#include "oneflow/core/device/event_record.h"
#include "oneflow/core/eager/critical_section_phy_instr_operand.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/notifier.h"
namespace oneflow {
namespace one {
using EagerBlobObjectListPtr =
std::shared_ptr<const std::vector<std::shared_ptr<vm::EagerBlobObject>>>;
}
namespace vm {
class LaunchLazyJobPhyInstrOperand final : public PhyInstrOperand {
public:
LaunchLazyJobPhyInstrOperand(const LaunchLazyJobPhyInstrOperand&) = delete;
LaunchLazyJobPhyInstrOperand(LaunchLazyJobPhyInstrOperand&&) = delete;
~LaunchLazyJobPhyInstrOperand() override = default;
LaunchLazyJobPhyInstrOperand(const std::shared_ptr<NNGraphIf>& nn_graph,
const one::EagerBlobObjectListPtr& param_blob_objects)
: nn_graph_(nn_graph),
param_blob_objects_(param_blob_objects),
input_dependences_(),
output_dependences_() {
ForEachConstMirroredObject(SetInserter(&input_dependences_));
ForEachMutMirroredObject(SetInserter(&output_dependences_));
ForEachMut2MirroredObject(SetInserter(&output_dependences_));
stream_sequential_dependence_ = nullptr;
}
const std::shared_ptr<NNGraphIf>& nn_graph() const { return nn_graph_; }
const DependenceVector& input_dependences() const override { return input_dependences_; }
const DependenceVector& output_dependences() const override { return output_dependences_; }
void ForEachConstMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}
void ForEachMutMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
void ForEachMut2MirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const {}
void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
for (const auto& eager_blob_object : *param_blob_objects_) { DoEach(eager_blob_object.get()); }
}
private:
std::shared_ptr<NNGraphIf> nn_graph_;
one::EagerBlobObjectListPtr param_blob_objects_;
DependenceVector input_dependences_;
DependenceVector output_dependences_;
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_LAZY_JOB_PHY_INSTR_OPERAND_H_
@@ -20,12 +20,16 @@ limitations under the License.
 #include "oneflow/core/vm/vm_object.h"
 #include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/symbol.h"
+#include "oneflow/core/common/small_vector.h"
+#include "oneflow/core/common/op_args_reserved_size.h"
 #include "oneflow/core/framework/device.h"

 namespace oneflow {

 // LocalDepObject helps VirtualMachineEngine building instruction edges
-using LocalDepObject = vm::MirroredObject;
+using LocalDepObject = vm::Dependence;
+
+using DependenceVector = small_vector<LocalDepObject*, kOpArgsReservedSize>;

 intrusive::shared_ptr<LocalDepObject> NewLocalDepObject();
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/protobuf.h"
#ifdef WITH_ROCM
#include "oneflow/core/ep/rocm/cuda_stream.h"
#else
#include "oneflow/core/ep/cuda/cuda_stream.h"
#endif
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/vm/allocator.h"
#include "oneflow/core/vm/thread_ctx.h"
#include "oneflow/core/eager/op_call_instruction_type.h"
#include "oneflow/core/eager/op_call_phy_instr_operand.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/operator/op_conf_symbol.h"
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/common/cpp_attribute.h"
namespace oneflow {
namespace vm {
struct OpCallInstructionUtil final {
static inline Maybe<void> Prepare(const vm::Instruction& instruction) {
auto* operand = GetCallPhyInstrOperand(instruction);
DeviceCtx* device_ctx = instruction.stream().device_ctx().get();
JUST(AllocateOutputBlobsMemory(operand, device_ctx));
if (unlikely(operand->need_temp_storage())) {
InferTempStorageSize(operand);
JUST(TryAllocateTempStorage(operand, device_ctx));
// Since memory block is cached in allocator, it's safe to deallocate tmp buffer before
// kernel executed.
DeallocateTempStorage(operand, device_ctx);
}
return Maybe<void>::Ok();
}
static inline void Compute(const vm::Instruction& instruction) {
auto* operand = GetCallPhyInstrOperand(instruction);
DeviceCtx* device_ctx = instruction.stream().device_ctx().get();
if (!operand->is_all_outputs_pod()) {
for (const auto& blob_object : *operand->outputs()) {
blob_object->TryInitNonPODTypeEagerBlobObjectIfNeed();
}
}
user_op::OpKernelState* state = nullptr;
user_op::OpKernelCache* cache = nullptr;
if (operand->user_opkernel()->has_state_or_cache()) {
TryInitOpKernelStateAndCache(operand, device_ctx, &state, &cache);
}
OpKernelCompute(operand, device_ctx, state, cache);
}
static inline OpCallPhyInstrOperand* GetCallPhyInstrOperand(const vm::Instruction& instruction) {
auto* operand = CHECK_NOTNULL(instruction.phy_instr_operand().get());
return CHECK_NOTNULL(dynamic_cast<OpCallPhyInstrOperand*>(operand));
}
private:
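// Queries the kernel's InferTmpSize function and records the result on the call context's
// temp tensor.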
static inline void InferTempStorageSize(OpCallPhyInstrOperand* operand) {
auto* tmp_tensor = operand->mut_call_ctx()->mut_tmp_tensor();
size_t temp_size =
operand->opkernel().InferTmpSize(&operand->call_ctx_, operand->user_opkernel());
tmp_tensor->set_tmp_buffer_size(temp_size);
}
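// Resolves the OpKernelState and OpKernelCache for this call; a state already provided via the
// interpreter context is reused instead of being re-initialized.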
static inline void TryInitOpKernelStateAndCache(OpCallPhyInstrOperand* operand,
DeviceCtx* device_ctx,
user_op::OpKernelState** state,
user_op::OpKernelCache** cache) {
OF_PROFILER_RANGE_GUARD("TryInitOpKernelStateAndCache");
if (likely(operand->op_interp_ctx().state)) {
*state = operand->op_interp_ctx().state.get();
// Null out the local `state` pointer so that the state initialization inside
// TryInitOpKernelStateAndCache below is skipped.
state = nullptr;
}
operand->mut_opkernel()->TryInitOpKernelStateAndCache(&operand->call_ctx_, device_ctx,
operand->user_opkernel(), state, cache);
}
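// Allocates blob body memory for every output eager blob object on the instruction's device.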
static inline Maybe<void> AllocateOutputBlobsMemory(OpCallPhyInstrOperand* operand,
DeviceCtx* device_ctx) {
OF_PROFILER_RANGE_GUARD("AllocateOutputBlobsMemory");
for (const auto& blob_object : *operand->outputs()) {
JUST(blob_object->TryAllocateBlobBodyMemory(device_ctx));
}
return Maybe<void>::Ok();
}
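// Allocates the temp buffer from the stream allocator when a non-zero size has been inferred.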
static inline Maybe<void> TryAllocateTempStorage(OpCallPhyInstrOperand* operand,
DeviceCtx* device_ctx) {
OF_PROFILER_RANGE_GUARD("TryAllocateTempStorage");
auto* tmp_tensor = operand->mut_call_ctx()->mut_tmp_tensor();
size_t byte_size = tmp_tensor->tmp_buffer_size();
if (byte_size > 0) {
char* mem_ptr = nullptr;
JUST(device_ctx->mut_allocator()->Allocate(&mem_ptr, byte_size));
tmp_tensor->init_tmp_buffer_ptr(mem_ptr);
}
return Maybe<void>::Ok();
}
static inline void OpKernelCompute(OpCallPhyInstrOperand* operand, DeviceCtx* device_ctx,
user_op::OpKernelState* state, user_op::OpKernelCache* cache) {
auto* call_ctx = &operand->call_ctx_;
auto* user_kernel = operand->user_opkernel();
operand->mut_opkernel()->Compute(call_ctx, device_ctx, user_kernel, state, cache);
}
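// Hands the temp buffer back to the caching allocator (see the note in Prepare above).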
static inline void DeallocateTempStorage(OpCallPhyInstrOperand* operand, DeviceCtx* device_ctx) {
OF_PROFILER_RANGE_GUARD("DeallocateTempStorage");
auto* tmp_tensor = operand->mut_call_ctx()->mut_tmp_tensor();
device_ctx->mut_allocator()->Deallocate(tmp_tensor->mut_tmp_buffer_ptr(),
tmp_tensor->tmp_buffer_size());
}
};
Maybe<void> OpCallInstructionType::Prepare(vm::Instruction* instruction) const {
return OpCallInstructionUtil::Prepare(*instruction);
}
void OpCallInstructionType::Compute(vm::Instruction* instruction) const {
OpCallInstructionUtil::Compute(*instruction);
}
std::string OpCallInstructionType::DebugName(const vm::Instruction& instruction) const {
auto* operand = CHECK_NOTNULL(instruction.phy_instr_operand().get());
return CHECK_NOTNULL(dynamic_cast<OpCallPhyInstrOperand*>(operand))->opkernel().op_type_name()
+ ":OpCall";
}
} // namespace vm
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/memory/memory_case.pb.h"
namespace oneflow {
namespace vm {
class OpCallInstructionType final : public vm::InstructionType {
public:
OpCallInstructionType() = default;
~OpCallInstructionType() = default;
Maybe<void> Prepare(vm::Instruction* instruction) const override;
void Compute(vm::Instruction* instruction) const override;
InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; }
std::string DebugName(const vm::Instruction& instruction) const override;
private:
Maybe<void> MaybeCompute(vm::Instruction* instruction) const;
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_OP_CALL_INSTRUCTION_TYPE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/eager/op_call_phy_instr_operand.h"
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h"
#include "oneflow/core/framework/stream_is_comm_net_stream.h"
#include "oneflow/core/vm/stream.h"
namespace oneflow {
namespace vm {
OpCallPhyInstrOperand::OpCallPhyInstrOperand(
vm::Stream* vm_stream, const std::shared_ptr<one::StatefulOpKernel>& opkernel,
const one::EagerBlobObjectListPtr& inputs, const one::EagerBlobObjectListPtr& outputs,
const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result,
const one::OpExprInterpContext& op_interp_ctx,
const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode)
: vm_stream_(vm_stream),
call_ctx_(ComposedAttrMap(op_interp_ctx.attrs, opkernel->base_attrs()), inputs, outputs,
consistent_tensor_infer_result, op_interp_ctx, opkernel->mem_case()),
opkernel_(opkernel),
user_opkernel_(nullptr),
infer_tmp_size_fn_(nullptr),
need_temp_storage_(false),
dev_vm_dep_object_consume_mode_(dev_vm_dep_object_consume_mode),
input_dependences_(),
output_dependences_(),
is_all_outputs_pod_(true) {
ForEachConstMirroredObject(SetInserter(&input_dependences_));
ForEachMutMirroredObject(SetInserter(&output_dependences_));
ForEachMut2MirroredObject(SetInserter(&output_dependences_));
InitStreamSequentialDependence();
for (const auto& blob_object : *outputs) {
is_all_outputs_pod_ = is_all_outputs_pod_ && IsPODDataType(blob_object->data_type());
}
}
Maybe<void> OpCallPhyInstrOperand::Init() {
return mut_opkernel()->ChooseOpKernel(&call_ctx_, &user_opkernel_, &need_temp_storage_);
}
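// Visits the compute dependence object of every read-only (const) input.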
void OpCallPhyInstrOperand::ForEachConstMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
const auto& input_list = inputs();
for (int64_t index : opkernel().input_tuple_indexes4const_ibns()) {
const auto& input = input_list->at(index);
DoEach(CHECK_JUST(input->compute_local_dep_object()));
}
}
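// Decides whether this call must additionally be serialized with other instructions scheduled
// on the same stream.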
void OpCallPhyInstrOperand::InitStreamSequentialDependence() {
auto* device_schedule_dep_object = vm_stream_->schedule_local_dep_object().get();
if (IsCommNetStream::Visit(vm_stream_->stream_role())) {
// Sequentialize nccl instructions to avoid deadlock
stream_sequential_dependence_ = device_schedule_dep_object;
} else {
// Sequentialize instructions to avoid explosive memory allocation of source ops
if (dev_vm_dep_object_consume_mode() == one::DevVmDepObjectConsumeMode::MUTABLE) {
stream_sequential_dependence_ = device_schedule_dep_object;
} else if (opkernel().input_tuple_indexes4const_ibns().empty()
&& opkernel().input_tuple_indexes4mut_ibns().empty()) {
stream_sequential_dependence_ = device_schedule_dep_object;
}
}
}
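// Visits mutable dependences: the stream's optional transport object plus all mutable inputs
// and outputs.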
void OpCallPhyInstrOperand::ForEachMutMirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
const auto& opt_transport_dep_object = vm_stream_->transport_local_dep_object();
if (opt_transport_dep_object.has_value()) { DoEach(CHECK_JUST(opt_transport_dep_object)->get()); }
const auto& input_list = inputs();
for (int64_t index : opkernel().input_tuple_indexes4mut_ibns()) {
const auto& input = input_list->at(index);
DoEach(CHECK_JUST(input->compute_local_dep_object()));
}
const auto& output_list = outputs();
for (int64_t index : opkernel().output_tuple_indexes4mut_obns()) {
const auto& output = output_list->at(index);
DoEach(CHECK_JUST(output->compute_local_dep_object()));
}
}
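// Visits the compute dependence objects of outputs registered through mut2 output names.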
void OpCallPhyInstrOperand::ForEachMut2MirroredObject(
const std::function<void(vm::MirroredObject* compute)>& DoEach) const {
const auto& output_list = outputs();
for (int64_t index : opkernel().output_tuple_indexes4mut2_obns()) {
const auto& output = output_list->at(index);
DoEach(CHECK_JUST(output->compute_local_dep_object()));
}
}
} // namespace vm
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_
#define ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_
#include "oneflow/core/vm/phy_instr_operand.h"
#include "oneflow/core/eager/call_context.h"
#include "oneflow/core/eager/dev_vm_dep_object_consume_mode.h"
#include "oneflow/core/framework/user_op_kernel_registry.h"
namespace oneflow {
namespace user_op {
class OpKernel;
} // namespace user_op
namespace vm {
class Stream;
struct OpCallInstructionUtil;
class OpCallPhyInstrOperand final : public vm::PhyInstrOperand {
public:
OpCallPhyInstrOperand(const OpCallPhyInstrOperand&) = delete;
OpCallPhyInstrOperand(OpCallPhyInstrOperand&&) = delete;
~OpCallPhyInstrOperand() override = default;
template<typename... Args>
static Maybe<OpCallPhyInstrOperand> New(Args&&... args) {
auto* ptr = new OpCallPhyInstrOperand(std::forward<Args>(args)...);
JUST(ptr->Init());
return std::shared_ptr<OpCallPhyInstrOperand>(ptr);
}
const one::StatefulOpKernel& opkernel() const { return *opkernel_; }
const one::EagerBlobObjectListPtr& inputs() const { return call_ctx_.inputs(); }
const one::EagerBlobObjectListPtr& outputs() const { return call_ctx_.outputs(); }
const AttrMap& attrs() const { return call_ctx_.op_interp_ctx().attrs; }
const one::OpExprInterpContext& op_interp_ctx() const { return call_ctx_.op_interp_ctx(); }
const one::DevVmDepObjectConsumeMode& dev_vm_dep_object_consume_mode() const {
return dev_vm_dep_object_consume_mode_;
}
bool is_all_outputs_pod() const { return is_all_outputs_pod_; }
one::StatefulOpKernel* mut_opkernel() { return opkernel_.get(); }
template<typename DoEachT>
Maybe<void> ForEachOutputTensor(const DoEachT& DoEach) {
for (const auto& output : *outputs()) { JUST(DoEach(output.get())); }
return Maybe<void>::Ok();
}
const DependenceVector& input_dependences() const override { return input_dependences_; }
const DependenceVector& output_dependences() const override { return output_dependences_; }
void ForEachConstMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
void ForEachMutMirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
void ForEachMut2MirroredObject(const std::function<void(vm::MirroredObject* compute)>&) const;
bool need_temp_storage() const { return need_temp_storage_; }
const user_op::OpKernel* user_opkernel() const { return user_opkernel_; }
const user_op::InferTmpSizeFn& infer_tmp_size_fn() const { return *infer_tmp_size_fn_; }
const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result()
const {
return call_ctx_.consistent_tensor_infer_result();
}
eager::CallContext* mut_call_ctx() { return &call_ctx_; }
void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
for (const auto& eager_blob_object : *call_ctx_.inputs()) { DoEach(eager_blob_object.get()); }
}
private:
friend struct OpCallInstructionUtil;
OpCallPhyInstrOperand(
vm::Stream* vm_stream, const std::shared_ptr<one::StatefulOpKernel>& opkernel,
const one::EagerBlobObjectListPtr& inputs, const one::EagerBlobObjectListPtr& outputs,
const std::shared_ptr<const one::ConsistentTensorInferResult>& consistent_tensor_infer_result,
const one::OpExprInterpContext& op_interp_ctx,
const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode);
Maybe<void> Init();
void InitStreamSequentialDependence();
vm::Stream* vm_stream_;
eager::CallContext call_ctx_;
std::shared_ptr<one::StatefulOpKernel> opkernel_;
const user_op::OpKernel* user_opkernel_;
const user_op::InferTmpSizeFn* infer_tmp_size_fn_;
bool need_temp_storage_;
const one::DevVmDepObjectConsumeMode dev_vm_dep_object_consume_mode_;
DependenceVector input_dependences_;
DependenceVector output_dependences_;
bool is_all_outputs_pod_;
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_OP_CALL_PHY_INSTR_OPERAND_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_RELEASE_TENSOR_ARG_PHY_INSTR_OPERAND_H_
#define ONEFLOW_CORE_EAGER_RELEASE_TENSOR_ARG_PHY_INSTR_OPERAND_H_
#include <functional>
#include <memory>
#include "oneflow/core/intrusive/intrusive.h"
#include "oneflow/core/vm/phy_instr_operand.h"
#include "oneflow/core/eager/local_dep_object.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/vm/stream.h"
namespace oneflow {
namespace vm {
class EagerBlobObject;
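// Operand that carries the tensor to release; its compute dependence becomes the instruction's
// only output dependence, so the virtual machine orders the release after other instructions
// touching the tensor.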
class ReleaseTensorArgPhyInstrOperand : public PhyInstrOperand {
public:
ReleaseTensorArgPhyInstrOperand(const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
const Optional<vm::Stream*>& stream)
: eager_blob_object_(eager_blob_object), output_dependences_() {
output_dependences_.push_back(CHECK_JUST(eager_blob_object->compute_local_dep_object()));
if (stream.has_value()) {
stream_sequential_dependence_ = CHECK_JUST(stream)->schedule_local_dep_object().get();
}
}
~ReleaseTensorArgPhyInstrOperand() override = default;
const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object() const {
return eager_blob_object_;
}
const DependenceVector& input_dependences() const override {
static thread_local DependenceVector empty{};
return empty;
}
const DependenceVector& output_dependences() const override { return output_dependences_; }
void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {
DoEach(eager_blob_object_.get());
}
private:
std::shared_ptr<vm::EagerBlobObject> eager_blob_object_;
DependenceVector output_dependences_;
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_RELEASE_TENSOR_ARG_PHY_INSTR_OPERAND_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/vm/ep_optional_event_record_status_querier.h"
#include "oneflow/core/eager/release_tensor_arg_phy_instr_operand.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/common/stream_role.h"
#include "oneflow/core/common/singleton_ptr.h"
namespace oneflow {
namespace vm {
class ReleaseTensorInstructionType : public vm::InstructionType {
public:
ReleaseTensorInstructionType() = default;
~ReleaseTensorInstructionType() override = default;
InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAtAnyPosition; }
std::string DebugName(const vm::Instruction& instruction) const override {
return "ReleaseTensor";
}
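// Prepare releases tensors with POD element types; non-POD tensors are released in Compute.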
Maybe<void> Prepare(vm::Instruction* instruction) const override {
const auto& eager_blob_object = GetEagerBlobObject(*instruction);
DataType data_type = eager_blob_object->data_type();
if (IsPODDataType(data_type)) { Release(eager_blob_object); }
return Maybe<void>::Ok();
}
void Compute(vm::Instruction* instruction) const override {
const auto& eager_blob_object = GetEagerBlobObject(*instruction);
DataType data_type = eager_blob_object->data_type();
if (!IsPODDataType(data_type)) { Release(eager_blob_object); }
}
void InitInstructionStatus(Instruction* instruction) const override {
auto* status_buffer = instruction->mut_status_buffer();
auto* stream = instruction->mut_stream();
instruction->stream_type().InitInstructionStatus(*stream, status_buffer);
auto* data_ptr = status_buffer->mut_buffer();
EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(nullptr);
}
private:
const std::shared_ptr<vm::EagerBlobObject>& GetEagerBlobObject(
const vm::Instruction& instruction) const {
const auto& phy_instr_operand = instruction.phy_instr_operand();
CHECK(static_cast<bool>(phy_instr_operand));
const auto* ptr =
dynamic_cast<const vm::ReleaseTensorArgPhyInstrOperand*>(phy_instr_operand.get());
CHECK_NOTNULL(ptr);
return ptr->eager_blob_object();
}
void Release(const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object) const {
CHECK_JUST(eager_blob_object->DeallocateBlobDataPtr());
}
};
} // namespace vm
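// Maps a stream role to the instruction type used to release tensors; barrier, critical-section
// and lazy-job streams never release tensors this way.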
struct GetReleaseInstructionType : public StreamRoleVisitor<GetReleaseInstructionType> {
static Maybe<const vm::InstructionType*> VisitCompute(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitHost2Device(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitDevice2Host(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitSyncedLaunchedCommNet(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitAsyncedLaunchedCommNet(DeviceType device_type) {
return SingletonPtr<vm::ReleaseTensorInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitBarrier(DeviceType device_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitCriticalSection(DeviceType device_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitLazyJobLauncher(DeviceType device_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitPinnedCompute(DeviceType device_type) {
return VisitCompute(device_type);
}
};
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_RELEASE_TENSOR_INSTRUCTION_TYPE_H_
...@@ -75,6 +75,9 @@ class Cache { ...@@ -75,6 +75,9 @@ class Cache {
} }
virtual void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index, virtual void Dump(ep::Stream* stream, uint64_t start_key_index, uint64_t end_key_index,
uint32_t* n_dumped, void* keys, void* values) = 0; uint32_t* n_dumped, void* keys, void* values) = 0;
virtual void ClearDirtyFlags() = 0;
virtual void Clear() = 0; virtual void Clear() = 0;
}; };
......
...@@ -462,7 +462,7 @@ TEST(Cache, FullCache) { ...@@ -462,7 +462,7 @@ TEST(Cache, FullCache) {
// TestCache(cache.get(), line_size); // TestCache(cache.get(), line_size);
// } // }
#endif #endif // WITH_ROCM
} // namespace } // namespace
......
...@@ -45,22 +45,26 @@ class CacheKeyValueStoreImpl : public KeyValueStore { ...@@ -45,22 +45,26 @@ class CacheKeyValueStoreImpl : public KeyValueStore {
OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl); OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl);
CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache) CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache)
: store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) { : store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) {
OF_CUDA_CHECK(cudaGetDevice(&device_index_)); OF_CUDA_CHECK(GPU(GetDevice)(&device_index_));
CHECK_EQ(store_->KeySize(), cache_->KeySize()); CHECK_EQ(store_->KeySize(), cache_->KeySize());
CHECK_EQ(store_->ValueSize(), cache_->ValueSize()); CHECK_EQ(store_->ValueSize(), cache_->ValueSize());
OF_CUDA_CHECK(cudaMalloc(&num_buffer_, sizeof(uint32_t))); OF_CUDA_CHECK(GPU(Malloc)(&num_buffer_, sizeof(uint32_t)));
#ifdef WITH_ROCM
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&host_num_buffer_), sizeof(uint32_t)));
#else
OF_CUDA_CHECK(cudaMallocHost(&host_num_buffer_, sizeof(uint32_t))); OF_CUDA_CHECK(cudaMallocHost(&host_num_buffer_, sizeof(uint32_t)));
#endif
num_elems_per_value_ = store_->ValueSize() / sizeof(Elem); num_elems_per_value_ = store_->ValueSize() / sizeof(Elem);
} }
~CacheKeyValueStoreImpl() { ~CacheKeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_); CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(cudaFree(num_buffer_)); OF_CUDA_CHECK(GPU(Free)(num_buffer_));
OF_CUDA_CHECK(cudaFreeHost(host_num_buffer_)); OF_CUDA_CHECK(GPU(FreeHost)(host_num_buffer_));
if (max_query_length_ != 0) { if (max_query_length_ != 0) {
OF_CUDA_CHECK(cudaFree(keys_buffer_)); OF_CUDA_CHECK(GPU(Free)(keys_buffer_));
OF_CUDA_CHECK(cudaFree(values_buffer_)); OF_CUDA_CHECK(GPU(Free)(values_buffer_));
OF_CUDA_CHECK(cudaFree(indices_buffer0_)); OF_CUDA_CHECK(GPU(Free)(indices_buffer0_));
OF_CUDA_CHECK(cudaFree(indices_buffer1_)); OF_CUDA_CHECK(GPU(Free)(indices_buffer1_));
} }
cache_.reset(); cache_.reset();
store_.reset(); store_.reset();
...@@ -76,15 +80,15 @@ class CacheKeyValueStoreImpl : public KeyValueStore { ...@@ -76,15 +80,15 @@ class CacheKeyValueStoreImpl : public KeyValueStore {
if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); } if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); }
if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); } if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); }
if (max_query_length_ != 0) { if (max_query_length_ != 0) {
OF_CUDA_CHECK(cudaFree(keys_buffer_)); OF_CUDA_CHECK(GPU(Free)(keys_buffer_));
OF_CUDA_CHECK(cudaFree(values_buffer_)); OF_CUDA_CHECK(GPU(Free)(values_buffer_));
OF_CUDA_CHECK(cudaFree(indices_buffer0_)); OF_CUDA_CHECK(GPU(Free)(indices_buffer0_));
OF_CUDA_CHECK(cudaFree(indices_buffer1_)); OF_CUDA_CHECK(GPU(Free)(indices_buffer1_));
} }
OF_CUDA_CHECK(cudaMalloc(&keys_buffer_, query_length * store_->KeySize())); OF_CUDA_CHECK(GPU(Malloc)(&keys_buffer_, query_length * store_->KeySize()));
OF_CUDA_CHECK(cudaMalloc(&values_buffer_, query_length * store_->ValueSize())); OF_CUDA_CHECK(GPU(Malloc)(&values_buffer_, query_length * store_->ValueSize()));
OF_CUDA_CHECK(cudaMalloc(&indices_buffer0_, query_length * sizeof(uint32_t))); OF_CUDA_CHECK(GPU(Malloc)(&indices_buffer0_, query_length * sizeof(uint32_t)));
OF_CUDA_CHECK(cudaMalloc(&indices_buffer1_, query_length * sizeof(uint32_t))); OF_CUDA_CHECK(GPU(Malloc)(&indices_buffer1_, query_length * sizeof(uint32_t)));
max_query_length_ = query_length; max_query_length_ = query_length;
} }
...@@ -136,17 +140,17 @@ void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_key ...@@ -136,17 +140,17 @@ void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_key
} else { } else {
cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_); cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_);
} }
OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault, OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, num_buffer_, sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream())); cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync()); CHECK_JUST(cuda_stream->Sync());
const uint32_t num_cache_missing = *host_num_buffer_; const uint32_t num_cache_missing = *host_num_buffer_;
if (num_cache_missing == 0) { if (num_cache_missing == 0) {
OF_CUDA_CHECK(cudaMemsetAsync(n_missing, 0, sizeof(uint32_t), OF_CUDA_CHECK(GPU(MemsetAsync)(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream())); stream->As<ep::CudaStream>()->cuda_stream()));
return; return;
} }
store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_); store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_);
OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), cudaMemcpyDefault, OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, n_missing, sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream())); cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync()); CHECK_JUST(cuda_stream->Sync());
const uint32_t num_store_missing = *host_num_buffer_; const uint32_t num_store_missing = *host_num_buffer_;
...@@ -173,9 +177,12 @@ void CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_key ...@@ -173,9 +177,12 @@ void CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_key
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
synced_ = false; synced_ = false;
auto cuda_stream = stream->As<ep::CudaStream>(); auto cuda_stream = stream->As<ep::CudaStream>();
if (cache_->Policy() != CacheOptions::Policy::kFull) {
OF_CUDA_CHECK(GPU(MemsetAsync)(num_buffer_, 0, sizeof(uint32_t), cuda_stream->cuda_stream()));
}
cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_); cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_);
if (cache_->Policy() == CacheOptions::Policy::kFull) { return; } if (cache_->Policy() == CacheOptions::Policy::kFull) { return; }
OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), cudaMemcpyDefault, OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, num_buffer_, sizeof(uint32_t), GPU(MemcpyDefault),
cuda_stream->cuda_stream())); cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync()); CHECK_JUST(cuda_stream->Sync());
store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_); store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
...@@ -187,6 +194,10 @@ void CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, u ...@@ -187,6 +194,10 @@ void CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, u
const void* update, const float* lr, const void* update, const float* lr,
float scale) { float scale) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
if (cache_->Policy() != CacheOptions::Policy::kFull) {
OF_CUDA_CHECK(GPU(MemsetAsync)(num_buffer_, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
}
if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) { if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
...@@ -221,17 +232,13 @@ void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot( ...@@ -221,17 +232,13 @@ void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(
auto* cuda_stream = stream->As<ep::CudaStream>(); auto* cuda_stream = stream->As<ep::CudaStream>();
while (true) { while (true) {
iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_); iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_);
OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(GPU(DeviceSynchronize)());
OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, num_buffer_, sizeof(uint32_t),
cudaMemcpyDefault, cuda_stream->cuda_stream())); GPU(MemcpyDefault), cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync()); CHECK_JUST(stream->Sync());
if (*host_num_buffer_ == 0) { return; } if (*host_num_buffer_ == 0) { return; }
cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr, cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr,
nullptr); nullptr);
OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
cudaMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
CHECK_EQ(*host_num_buffer_, 0);
} }
} }
if (Hook) { if (Hook) {
...@@ -267,13 +274,14 @@ void CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() { ...@@ -267,13 +274,14 @@ void CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() {
cache_->Dump(stream, start_key_index, cache_->Dump(stream, start_key_index,
std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_, std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_,
keys_buffer_, values_buffer_); keys_buffer_, values_buffer_);
OF_CUDA_CHECK(cudaMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), OF_CUDA_CHECK(GPU(MemcpyAsync)(host_num_buffer_, num_buffer_, sizeof(uint32_t),
cudaMemcpyDefault, cuda_stream->cuda_stream())); GPU(MemcpyDefault), cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync()); CHECK_JUST(stream->Sync());
if (*host_num_buffer_ == 0) { continue; } if (*host_num_buffer_ == 0) { continue; }
store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_); store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
CHECK_JUST(stream->Sync()); CHECK_JUST(stream->Sync());
} }
cache_->ClearDirtyFlags();
device->DestroyStream(stream); device->DestroyStream(stream);
synced_ = true; synced_ = true;
} }
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "hip/hip_runtime.h"
#include "oneflow/core/embedding/cached_key_value_store.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
namespace oneflow {
namespace embedding {
namespace {
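// Scatters values fetched from the backing store into the caller's output buffer at the
// positions that missed the cache, and remaps store-missing indices back to the original key
// indices.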
template<typename Key, typename Elem>
__global__ void PostStoreGetKernel(uint32_t num_cache_missing, uint32_t num_store_missing,
uint32_t num_elems_per_value,
const uint32_t* cache_missing_indices,
const uint32_t* store_missing_indices, const Elem* store_values,
Elem* values, uint32_t* missing_indices) {
const uint32_t num_cache_missing_elem = num_cache_missing * num_elems_per_value;
CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_cache_missing_elem) {
const uint32_t value_index = i / num_elems_per_value;
const uint32_t elem_index = i - value_index * num_elems_per_value;
values[cache_missing_indices[value_index] * num_elems_per_value + elem_index] = store_values[i];
}
CUDA_1D_KERNEL_LOOP_T(uint32_t, i, num_store_missing) {
missing_indices[i] = cache_missing_indices[store_missing_indices[i]];
}
}
template<typename Key, typename Elem>
class CacheKeyValueStoreImpl : public KeyValueStore {
public:
OF_DISALLOW_COPY_AND_MOVE(CacheKeyValueStoreImpl);
CacheKeyValueStoreImpl(std::unique_ptr<KeyValueStore>&& store, std::unique_ptr<Cache>&& cache)
: store_(std::move(store)), cache_(std::move(cache)), synced_(true), max_query_length_(0) {
OF_CUDA_CHECK(hipGetDevice(&device_index_));
CHECK_EQ(store_->KeySize(), cache_->KeySize());
CHECK_EQ(store_->ValueSize(), cache_->ValueSize());
OF_CUDA_CHECK(hipMalloc(&num_buffer_, sizeof(uint32_t)));
OF_CUDA_CHECK(hipMallocHost(reinterpret_cast<void **>(&host_num_buffer_), sizeof(uint32_t)));
num_elems_per_value_ = store_->ValueSize() / sizeof(Elem);
}
~CacheKeyValueStoreImpl() {
CudaCurrentDeviceGuard guard(device_index_);
OF_CUDA_CHECK(hipFree(num_buffer_));
OF_CUDA_CHECK(hipHostFree(host_num_buffer_));
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(keys_buffer_));
OF_CUDA_CHECK(hipFree(values_buffer_));
OF_CUDA_CHECK(hipFree(indices_buffer0_));
OF_CUDA_CHECK(hipFree(indices_buffer1_));
}
cache_.reset();
store_.reset();
}
uint32_t KeySize() const override { return store_->KeySize(); }
uint32_t ValueSize() const override { return store_->ValueSize(); }
uint32_t MaxQueryLength() const override { return max_query_length_; }
void ReserveQueryLength(uint32_t query_length) override {
CudaCurrentDeviceGuard guard(device_index_);
if (query_length <= max_query_length_) { return; }
if (query_length > cache_->MaxQueryLength()) { cache_->ReserveQueryLength(query_length); }
if (query_length > store_->MaxQueryLength()) { store_->ReserveQueryLength(query_length); }
if (max_query_length_ != 0) {
OF_CUDA_CHECK(hipFree(keys_buffer_));
OF_CUDA_CHECK(hipFree(values_buffer_));
OF_CUDA_CHECK(hipFree(indices_buffer0_));
OF_CUDA_CHECK(hipFree(indices_buffer1_));
}
OF_CUDA_CHECK(hipMalloc(&keys_buffer_, query_length * store_->KeySize()));
OF_CUDA_CHECK(hipMalloc(&values_buffer_, query_length * store_->ValueSize()));
OF_CUDA_CHECK(hipMalloc(&indices_buffer0_, query_length * sizeof(uint32_t)));
OF_CUDA_CHECK(hipMalloc(&indices_buffer1_, query_length * sizeof(uint32_t)));
max_query_length_ = query_length;
}
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint32_t* n_missing, uint32_t* missing_indices) override;
void Get(ep::Stream* stream, uint32_t num_keys, const void* keys, void* values,
uint8_t* mask) override;
void Put(ep::Stream* stream, uint32_t num_keys, const void* keys, const void* values) override;
void FusedHalfUpdatePut(ep::Stream* stream, uint32_t n_keys, const void* keys, const void* values,
const void* update, const float* lr, float scale) override;
bool IsFusionSupported() override {
return cache_->Policy() == CacheOptions::Policy::kFull
&& cache_->ValueType() == DataType::kFloat;
}
bool SnapshotExists(const std::string& name) override;
void LoadSnapshot(const std::string& name) override;
void SaveSnapshot(const std::string& name) override;
void LoadSnapshot(const std::string& name,
const std::function<void(KVIterator* iter)>& Hook) override;
private:
void SyncCacheToStore();
std::unique_ptr<KeyValueStore> store_;
std::unique_ptr<Cache> cache_;
uint32_t* num_buffer_{};
uint32_t* host_num_buffer_{};
Key* keys_buffer_{};
Elem* values_buffer_{};
uint32_t* indices_buffer0_{};
uint32_t* indices_buffer1_{};
int device_index_{};
uint32_t max_query_length_;
uint32_t num_elems_per_value_{};
std::recursive_mutex mutex_;
bool synced_;
};
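// Get: the cache is queried first; with a non-full cache policy, misses are forwarded to the
// backing store and the results are merged via PostStoreGetKernel.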
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint32_t* n_missing,
uint32_t* missing_indices) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
auto cuda_stream = stream->As<ep::CudaStream>();
if (cache_->Policy() == CacheOptions::Policy::kFull) {
cache_->Get(stream, num_keys, keys, values, n_missing, keys_buffer_, missing_indices);
return;
} else {
cache_->Get(stream, num_keys, keys, values, num_buffer_, keys_buffer_, indices_buffer0_);
}
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
const uint32_t num_cache_missing = *host_num_buffer_;
if (num_cache_missing == 0) {
OF_CUDA_CHECK(hipMemsetAsync(n_missing, 0, sizeof(uint32_t),
stream->As<ep::CudaStream>()->cuda_stream()));
return;
}
store_->Get(stream, num_cache_missing, keys_buffer_, values_buffer_, n_missing, indices_buffer1_);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, n_missing, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
const uint32_t num_store_missing = *host_num_buffer_;
RUN_CUDA_KERNEL((PostStoreGetKernel<Key, Elem>), stream, num_cache_missing * num_elems_per_value_,
num_cache_missing, num_store_missing, num_elems_per_value_, indices_buffer0_,
indices_buffer1_, values_buffer_, static_cast<Elem*>(values), missing_indices);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Get(ep::Stream* stream, uint32_t num_keys, const void* keys,
void* values, uint8_t* mask) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (cache_->Policy() == CacheOptions::Policy::kFull) {
cache_->Get(stream, num_keys, keys, values, mask);
return;
} else {
UNIMPLEMENTED();
}
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::Put(ep::Stream* stream, uint32_t num_keys, const void* keys,
const void* values) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
synced_ = false;
auto cuda_stream = stream->As<ep::CudaStream>();
cache_->Put(stream, num_keys, keys, values, num_buffer_, keys_buffer_, values_buffer_);
if (cache_->Policy() == CacheOptions::Policy::kFull) { return; }
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t), hipMemcpyDefault,
cuda_stream->cuda_stream()));
CHECK_JUST(cuda_stream->Sync());
store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::FusedHalfUpdatePut(ep::Stream* stream, uint32_t num_keys,
const void* keys, const void* values,
const void* update, const float* lr,
float scale) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (cache_->Policy() != CacheOptions::Policy::kFull || cache_->ValueType() != DataType::kFloat) {
UNIMPLEMENTED();
}
synced_ = false;
cache_->FusedHalfUpdatePut(stream, num_keys, keys, values, update, lr, scale, num_buffer_,
keys_buffer_, values_buffer_);
}
template<typename Key, typename Elem>
bool CacheKeyValueStoreImpl<Key, Elem>::SnapshotExists(const std::string& name) {
return store_->SnapshotExists(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(const std::string& name) {
LoadSnapshot(name, nullptr);
}
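// Clears the cache, streams the snapshot's key/value chunks back into it (full-cache policy
// only), and finally reloads the snapshot in the backing store.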
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::LoadSnapshot(
const std::string& name, const std::function<void(KVIterator* iter)>& Hook) {
CudaCurrentDeviceGuard guard(device_index_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
CHECK_GT(max_query_length_, 0);
cache_->Clear();
auto device =
Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
CHECK(device);
auto* stream = device->CreateStream();
store_->LoadSnapshot(name, [&](KVIterator* iter) {
if (cache_->Policy() == CacheOptions::Policy::kFull) {
auto* cuda_stream = stream->As<ep::CudaStream>();
while (true) {
iter->NextN(stream, max_query_length_, num_buffer_, keys_buffer_, values_buffer_);
OF_CUDA_CHECK(hipDeviceSynchronize());
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
if (*host_num_buffer_ == 0) { return; }
cache_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_, num_buffer_, nullptr,
nullptr);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
CHECK_EQ(*host_num_buffer_, 0);
}
}
if (Hook) {
iter->Reset();
Hook(iter);
}
});
device->DestroyStream(stream);
store_->LoadSnapshot(name);
}
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SaveSnapshot(const std::string& name) {
CudaCurrentDeviceGuard guard(device_index_);
std::lock_guard<std::recursive_mutex> lock(mutex_);
SyncCacheToStore();
store_->SaveSnapshot(name);
}
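// Dumps the cache contents in chunks of at most max_query_length_ keys and writes each
// non-empty chunk back to the backing store.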
template<typename Key, typename Elem>
void CacheKeyValueStoreImpl<Key, Elem>::SyncCacheToStore() {
if (synced_) { return; }
CudaCurrentDeviceGuard guard(device_index_);
auto device =
Singleton<ep::DeviceManagerRegistry>::Get()->GetDevice(DeviceType::kCUDA, device_index_);
CHECK(device);
auto* stream = device->CreateStream();
auto* cuda_stream = stream->As<ep::CudaStream>();
const uint64_t dump_capacity = cache_->DumpCapacity();
CHECK_GT(max_query_length_, 0);
for (uint64_t start_key_index = 0; start_key_index < dump_capacity;
start_key_index += max_query_length_) {
cache_->Dump(stream, start_key_index,
std::min(start_key_index + max_query_length_, dump_capacity), num_buffer_,
keys_buffer_, values_buffer_);
OF_CUDA_CHECK(hipMemcpyAsync(host_num_buffer_, num_buffer_, sizeof(uint32_t),
hipMemcpyDefault, cuda_stream->cuda_stream()));
CHECK_JUST(stream->Sync());
if (*host_num_buffer_ == 0) { continue; }
store_->Put(stream, *host_num_buffer_, keys_buffer_, values_buffer_);
CHECK_JUST(stream->Sync());
}
device->DestroyStream(stream);
synced_ = true;
}
template<typename Key>
std::unique_ptr<KeyValueStore> DispatchElemType(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
const uint32_t value_size = store->ValueSize();
if (value_size % sizeof(uint4) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint4>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint64_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint64_t>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint32_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint32_t>(std::move(store), std::move(cache)));
} else if (value_size % sizeof(uint16_t) == 0) {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint16_t>(std::move(store), std::move(cache)));
} else {
return std::unique_ptr<KeyValueStore>(
new CacheKeyValueStoreImpl<Key, uint8_t>(std::move(store), std::move(cache)));
}
}
std::unique_ptr<KeyValueStore> DispatchKeyType(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
const uint32_t key_size = store->KeySize();
if (key_size == 4) {
return DispatchElemType<uint32_t>(std::move(store), std::move(cache));
} else if (key_size == 8) {
return DispatchElemType<uint64_t>(std::move(store), std::move(cache));
} else {
UNIMPLEMENTED();
return nullptr;
}
}
} // namespace
std::unique_ptr<KeyValueStore> NewCachedKeyValueStore(std::unique_ptr<KeyValueStore>&& store,
std::unique_ptr<Cache>&& cache) {
return DispatchKeyType(std::move(store), std::move(cache));
}
} // namespace embedding
} // namespace oneflow
\ No newline at end of file