GitLab · jerrrrry / infinicore · Commits

Commit 92472152, authored Dec 08, 2025 by pengcheng888
Parent: 793696d8

issue/713 - Add RowParallelLinear and ColParallelLinear for C++

Showing 4 changed files with 343 additions and 40 deletions:

  include/infinicore/nn/linear.hpp        +57   -7
  src/infinicore-test/test_nn_module.cc   +156  -0
  src/infinicore-test/test_nn_module.h    +1    -0
  src/infinicore/nn/linear.cc             +129  -33
include/infinicore/nn/linear.hpp  (view file @ 92472152)

 #pragma once

-#include "module.hpp"
 #include "../ops.hpp"
+#include "module.hpp"
+#include <infiniccl.h>

 namespace infinicore::nn {

-class Linear : public Module {
+class BaseLinear : public Module {
 public:
-    Linear(size_t in_features, size_t out_features, bool bias = true,
-           const DataType &dtype = DataType::F32, const Device &device = Device());
+    BaseLinear(size_t in_features, size_t out_features, bool bias = true,
+               const DataType &dtype = DataType::F32, const Device &device = Device());

     // Forward pass: output = input @ weight.T + bias
     Tensor forward(Tensor &input) const;
     ...
@@ -22,9 +24,6 @@ public:
     bool has_bias() const { return has_bias_; }
     DataType dtype() const { return dtype_; }

-    // String representation
-    std::string extra_repr() const;
-
     // Accessors for parameters
     Tensor weight() const { return weight_; }
     Tensor bias() const { return bias_; }
     ...
@@ -34,7 +33,7 @@ protected:
     INFINICORE_NN_PARAMETER(weight);
     INFINICORE_NN_PARAMETER(bias);

-private:
+protected:
     // Helper method for common forward computation
     Tensor compute_linear(Tensor &input) const;
     ...
@@ -45,3 +44,54 @@ private:
 };
 } // namespace infinicore::nn
+
+namespace infinicore::nn {
+class Linear : public BaseLinear {
+public:
+    Linear(size_t in_features, size_t out_features, bool bias = true,
+           const DataType &dtype = DataType::F32, const Device &device = Device());
+
+    // Forward pass: output = input @ weight.T + bias
+    Tensor forward(Tensor &input) const;
+
+    // String representation
+    std::string extra_repr() const;
+};
+
+class ColumnParallelLinear : public BaseLinear {
+public:
+    ColumnParallelLinear(size_t in_features, size_t out_features, bool bias = true,
+                         const DataType &dtype = DataType::F32, const Device &device = Device(),
+                         Size tp_rank = 0, Size tp_size = 1);
+
+    // Forward pass: output = input @ weight.T + bias
+    Tensor forward(Tensor &input) const;
+
+    // String representation
+    std::string extra_repr() const;
+
+protected:
+    Size tp_rank_ = 0;
+    Size tp_size_ = 1;
+};
+
+class RowParallelLinear : public BaseLinear {
+public:
+    RowParallelLinear(size_t in_features, size_t out_features, bool bias = true,
+                      const DataType &dtype = DataType::F32, const Device &device = Device(),
+                      Size tp_rank = 0, Size tp_size = 1,
+                      infinicclComm_t communicator = nullptr);
+
+    // Forward pass: output = input @ weight.T + bias
+    Tensor forward(Tensor &input) const;
+
+    // String representation
+    std::string extra_repr() const;
+
+protected:
+    Size tp_rank_ = 0;
+    Size tp_size_ = 1;
+    infinicclComm_t communicator_;
+};
+} // namespace infinicore::nn
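
The declarations above follow the usual Megatron-style tensor-parallel split: the column-parallel layer shards out_features, the row-parallel layer shards in_features and all-reduces its partial output. A minimal sketch of how the two classes could compose into one MLP shard per rank, assuming only the constructor signatures declared in this header (the sizes, the rank values, and the build_tp_mlp_shard helper are hypothetical):

#include "infinicore/nn/linear.hpp"

using namespace infinicore;

// Hypothetical helper: builds this rank's shard of a two-layer MLP.
void build_tp_mlp_shard(Size tp_rank, Size tp_size, infinicclComm_t comm) {
    // Column-parallel layer: splits the output dimension across ranks,
    // so each rank produces a disjoint slice of the hidden activation.
    nn::ColumnParallelLinear fc1(1024, 4096, /*bias=*/true,
                                 DataType::F32, Device(), tp_rank, tp_size);

    // Row-parallel layer: splits the input dimension; per the header, its
    // forward() all-reduces the partial sums through the communicator.
    nn::RowParallelLinear fc2(4096, 1024, /*bias=*/true,
                              DataType::F32, Device(), tp_rank, tp_size, comm);

    // Tensor x = ...;               // local input replica
    // auto h = fc1.forward(x);      // local shard of the hidden activation
    // auto y = fc2.forward(h);      // all-reduced full output on every rank
}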
src/infinicore-test/test_nn_module.cc  (view file @ 92472152)

 ...
@@ -297,6 +297,161 @@ TestResult NNModuleTest::testTensorParallelParameters() {
     });
 }
+
+TestResult NNModuleTest::testParalleLinear() {
+    return measureTime("ParalleLinear", [this]() {
+        try {
+            spdlog::info("==========================================");
+            spdlog::info("       Testing Tensor Parallel Linear     ");
+            spdlog::info("==========================================");
+
+            auto device = infinicore::context::getDevice();
+            spdlog::info("Test Tensor Parallel Linear");
+            spdlog::info(device.toString());
+
+            // Reference parameters for a 32x64 layer: w[i][j] = j, b[i] = i.
+            auto w_data = std::vector<float>(32 * 64);
+            auto b_data = std::vector<float>(32);
+            for (size_t i = 0; i < 32; ++i) {
+                for (size_t j = 0; j < 64; ++j) {
+                    w_data[i * 64 + j] = static_cast<float>(j);
+                }
+                b_data[i] = static_cast<float>(i);
+            }
+
+            {
+                spdlog::info("Test tp_size=4 tp_dim=0");
+                Size tp_size = 4;
+                std::vector<std::unique_ptr<infinicore::nn::ColumnParallelLinear>> tp_modules;
+                for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
+                    auto module = std::make_unique<infinicore::nn::ColumnParallelLinear>(
+                        64, 32, true, DataType::F32, device, tp_rank, tp_size);
+                    tp_modules.push_back(std::move(module));
+                }
+
+                // Verify each partition has the correct shape.
+                for (size_t i = 0; i < tp_modules.size(); ++i) {
+                    const auto &weight = tp_modules[i]->weight();
+                    const auto &bias = tp_modules[i]->bias();
+
+                    // Weight should be partitioned along the output dimension (dim 0).
+                    if (weight->shape() != std::vector<size_t>({8, 64})) { // 32/4 = 8
+                        spdlog::error("TP rank {}: Weight shape mismatch. Expected [8, 64], got [{}]",
+                                      i, formatShape(weight->shape()));
+                        return false;
+                    }
+                    // Bias should be partitioned along the output dimension as well.
+                    if (bias->shape() != std::vector<size_t>({8})) { // 32/4 = 8
+                        spdlog::error("TP rank {}: Bias shape mismatch. Expected [8], got [{}]",
+                                      i, formatShape(bias->shape()));
+                        return false;
+                    }
+                    spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]",
+                                  i, formatShape(weight->shape()), formatShape(bias->shape()));
+
+                    tp_modules[i]->load_parameter_from_blob("weight", w_data.data());
+                    tp_modules[i]->load_parameter_from_blob("bias", b_data.data());
+
+                    // Narrow the full blob to this rank's partition for comparison.
+                    auto weight_loaded = infinicore::Tensor::from_blob(w_data.data(), {32, 64},
+                                                                       infinicore::DataType::F32,
+                                                                       infinicore::Device::cpu())
+                                             ->narrow({{0, i * 8, 8}})
+                                             ->to(device);
+                    auto bias_loaded = infinicore::Tensor::from_blob(b_data.data(), {32},
+                                                                     infinicore::DataType::F32,
+                                                                     infinicore::Device::cpu())
+                                           ->narrow({{0, i * 8, 8}})
+                                           ->to(device);
+
+                    if (!tensorsAllClose(tp_modules[i]->weight(), weight_loaded, 1e-6, 1e-6)) {
+                        spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i);
+                        return false;
+                    }
+                    if (!tensorsAllClose(tp_modules[i]->bias(), bias_loaded, 1e-6, 1e-6)) {
+                        spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i);
+                        return false;
+                    }
+                }
+            }
+
+            {
+                spdlog::info("Test tp_size=4 tp_dim=1");
+                Size tp_size = 4;
+                std::vector<std::unique_ptr<infinicore::nn::RowParallelLinear>> tp_modules;
+                for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
+                    auto module = std::make_unique<infinicore::nn::RowParallelLinear>(
+                        64, 32, true, DataType::F32, device, tp_rank, tp_size);
+                    tp_modules.push_back(std::move(module));
+                }
+
+                // Verify each partition has the correct shape.
+                for (size_t i = 0; i < tp_modules.size(); ++i) {
+                    const auto &weight = tp_modules[i]->weight();
+                    const auto &bias = tp_modules[i]->bias();
+
+                    // Weight should be partitioned along the input dimension (dim 1).
+                    if (weight->shape() != std::vector<size_t>({32, 16})) { // 64/4 = 16
+                        spdlog::error("TP rank {}: Weight shape mismatch. Expected [32, 16], got [{}]",
+                                      i, formatShape(weight->shape()));
+                        return false;
+                    }
+                    // Bias is not partitioned when tp_dim=1.
+                    if (bias->shape() != std::vector<size_t>({32})) {
+                        spdlog::error("TP rank {}: Bias shape mismatch. Expected [32], got [{}]",
+                                      i, formatShape(bias->shape()));
+                        return false;
+                    }
+                    spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]",
+                                  i, formatShape(weight->shape()), formatShape(bias->shape()));
+
+                    tp_modules[i]->load_parameter_from_blob("weight", w_data.data());
+                    tp_modules[i]->load_parameter_from_blob("bias", b_data.data());
+
+                    // Narrow the full weight blob along dim 1; the bias is replicated.
+                    auto weight_loaded = infinicore::Tensor::from_blob(w_data.data(), {32, 64},
+                                                                       infinicore::DataType::F32,
+                                                                       infinicore::Device::cpu())
+                                             ->narrow({{1, i * 16, 16}})
+                                             ->to(device);
+                    auto bias_loaded = infinicore::Tensor::from_blob(b_data.data(), {32},
+                                                                     infinicore::DataType::F32,
+                                                                     infinicore::Device::cpu())
+                                           ->to(device);
+
+                    if (!tensorsAllClose(tp_modules[i]->weight(), weight_loaded, 1e-6, 1e-6)) {
+                        spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i);
+                        return false;
+                    }
+                    if (!tensorsAllClose(tp_modules[i]->bias(), bias_loaded, 1e-6, 1e-6)) {
+                        spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i);
+                        return false;
+                    }
+                }
+            }
+
+            spdlog::info("=== All Tensor Parallel Linear Tests Passed ===");
+            return true;
+        } catch (const std::exception &e) {
+            spdlog::error("Exception in testParalleLinear: {}", e.what());
+            return false;
+        }
+    });
+}

 // Test 2: Advanced load state dict functionality (hierarchical modules)
 TestResult NNModuleTest::testLoadStateDict() {
     return measureTime("AdvancedLoadStateDict", [this]() {
 ...
@@ -1894,6 +2049,7 @@ TestResult NNModuleTest::run() {
     results.push_back(testBasicModuleCreation());      // Merged: creation + parameters + state_dict + load
     results.push_back(testTensorParallelParameters()); // Tensor-parallel parameters
+    results.push_back(testParalleLinear());            // ParalleLinear
     results.push_back(testLoadStateDict());            // Advanced: hierarchical modules
     results.push_back(testModuleHierarchy());          // Demonstrates hierarchical construction
     results.push_back(testParameterLoading());         // Blob loading
     ...
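
The expected shapes in this test follow from even partitioning: the column-parallel weight splits out_features (32/4 = 8 rows per rank), while the row-parallel weight splits in_features (64/4 = 16 columns per rank) and leaves the bias whole. A small self-contained sketch of that shard arithmetic; the shard_for_rank helper is hypothetical, written only to mirror the offsets the test's narrow() calls use:

#include <cstddef>

struct Shard {
    std::size_t offset; // first index of this rank's slice
    std::size_t length; // slice length along the partitioned dimension
};

// Column parallel (tp_dim = 0): split out_features -> 32/4 = 8 rows per rank.
// Row parallel    (tp_dim = 1): split in_features  -> 64/4 = 16 cols per rank.
Shard shard_for_rank(std::size_t dim_size, std::size_t tp_rank, std::size_t tp_size) {
    std::size_t len = dim_size / tp_size; // assumes dim_size % tp_size == 0
    return {tp_rank * len, len};
}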
src/infinicore-test/test_nn_module.h  (view file @ 92472152)

 ...
@@ -90,6 +90,7 @@ public:
 private:
     TestResult testBasicModuleCreation();      // Merged: creation, parameters, state_dict, load_state_dict
     TestResult testTensorParallelParameters(); // Module with tensor parallel parameters
+    TestResult testParalleLinear();            // Module with ColumnParallelLinear, RowParallelLinear
     TestResult testLoadStateDict();            // Advanced: hierarchical modules
     TestResult testModuleHierarchy();          // Demonstrates proper hierarchical construction pattern
     TestResult testParameterLoading();         // Test blob parameter loading
     ...
src/infinicore/nn/linear.cc  (view file @ 92472152)

 #include "infinicore/nn/linear.hpp"
+#include "../utils.hpp"
 #include "infinicore/ops.hpp"
+#include "infinicore/ops/linear.hpp"
+#include <optional>
 #include <spdlog/spdlog.h>

 namespace infinicore::nn {

-Linear::Linear(size_t in_features, size_t out_features, bool bias, const DataType &dtype, const Device &device)
+BaseLinear::BaseLinear(size_t in_features, size_t out_features, bool bias, const DataType &dtype, const Device &device)
     : in_features_(in_features),
       out_features_(out_features),
       has_bias_(bias),
       dtype_(dtype) {
     device_ = device;
 }
+
+Tensor BaseLinear::compute_linear(Tensor &input) const {
+    // Ensure input is contiguous before creating views (required for matmul).
+    // This prevents hanging when the input tensor has a non-contiguous memory layout.
+    Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous();
+
+    // Use op::linear directly to match the Python backend's exact code path,
+    // ensuring identical computation and numerical results.
+    // Parameter inherits from Tensor, so we cast to Tensor explicitly.
+    Tensor weight_tensor = static_cast<const Tensor &>(weight_);
+    std::optional<Tensor> bias_opt = has_bias_
+        ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_))
+        : std::nullopt;
+    auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt);
+    return output;
+}
+
+Tensor BaseLinear::forward(Tensor &input) const {
+    return compute_linear(input);
+}
+
+Tensor BaseLinear::forward(Tensor &input, Tensor &residual) const {
+    auto output = compute_linear(input);
+    // Add residual: output = output + residual
+    infinicore::op::add_(output, output, residual);
+    return output;
+}
+} // namespace infinicore::nn
+
+namespace infinicore::nn {
+Linear::Linear(size_t in_features, size_t out_features, bool bias, const DataType &dtype, const Device &device)
+    : BaseLinear(in_features, out_features, bias, dtype, device) {

     // Initialize parameters using macro
     INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
     ...
@@ -22,52 +66,104 @@ Linear::Linear(size_t in_features, size_t out_features, bool bias, const DataTyp
         bias_ = Parameter(); // Default constructed empty parameter
     }

-    SPDLOG_DEBUG("Created Linear module: in_features={}, out_features={}, bias={}, dtype={}",
-                 in_features, out_features, bias, static_cast<int>(dtype_));
+    // SPDLOG_DEBUG("Created Linear module: in_features={}, out_features={}, bias={}, dtype={}",
+    //              in_features, out_features, bias, static_cast<int>(dtype_));
 }

-Tensor Linear::compute_linear(Tensor &input) const {
-    // Create output tensor with shape [batch_size, out_features]
-    auto output_shape = input->shape();
-    output_shape[output_shape.size() - 1] = out_features_;
-    auto output = Tensor::empty(output_shape, input->dtype(), input->device());
-
-    // Transpose weight: [out_features, in_features] -> [in_features, out_features]
-    auto weight_t = weight_->permute({1, 0});
-
-    if (has_bias_) {
-        // Broadcast bias to output shape
-        size_t ndim_diff = output->ndim() - 1;
-        std::vector<Stride> strides(ndim_diff, 0);
-        strides.push_back(bias_->stride(0));
-        auto bias_view = bias_->as_strided(output->shape(), strides);
-
-        // Compute matmul result separately, then add to output
-        infinicore::op::matmul_(output, input, weight_t);
-        infinicore::op::add_(output, output, bias_view);
-    } else {
-        // No bias: just compute output = input @ weight_t
-        infinicore::op::matmul_(output, input, weight_t);
-    }
-
-    return output;
-}
-
-Tensor Linear::forward(Tensor &input) const {
-    return compute_linear(input);
-}
-
-Tensor Linear::forward(Tensor &input, Tensor &residual) const {
-    auto output = compute_linear(input);
-    // Add residual: output = output + residual
-    infinicore::op::add_(output, output, residual);
-    return output;
-}
+Tensor Linear::forward(Tensor &input) const {
+    return BaseLinear::forward(input);
+}

 std::string Linear::extra_repr() const {
     return "Linear(in_features=" + std::to_string(in_features_) +
            ", out_features=" + std::to_string(out_features_) +
            ", bias=" + (has_bias_ ? "true" : "false") +
            ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
 }
 } // namespace infinicore::nn
+
+namespace infinicore::nn {
+ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_features, bool bias,
+                                           const DataType &dtype, const Device &device,
+                                           Size tp_rank, Size tp_size)
+    : BaseLinear(in_features, out_features, bias, dtype, device),
+      tp_rank_(tp_rank),
+      tp_size_(tp_size) {
+
+    // Initialize parameters using macro; weight is partitioned along dim 0 (out_features).
+    INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device, 0, tp_rank_, tp_size_));
+
+    // Register bias parameter if requested; it is partitioned along dim 0 as well.
+    if (bias) {
+        INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, tp_rank_, tp_size_));
+    } else {
+        bias_ = Parameter(); // Default constructed empty parameter
+    }
+
+    // SPDLOG_DEBUG("Created ColumnParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
+    //              in_features, out_features, bias, static_cast<int>(dtype_));
+}
+
+Tensor ColumnParallelLinear::forward(Tensor &input) const {
+    return BaseLinear::forward(input);
+}
+
+std::string ColumnParallelLinear::extra_repr() const {
+    return "ColumnParallelLinear(in_features=" + std::to_string(in_features_) +
+           ", out_features=" + std::to_string(out_features_) +
+           ", bias=" + (has_bias_ ? "true" : "false") +
+           ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
+}
+} // namespace infinicore::nn
+
+namespace infinicore::nn {
+RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bool bias,
+                                     const DataType &dtype, const Device &device,
+                                     Size tp_rank, Size tp_size, infinicclComm_t communicator)
+    : BaseLinear(in_features, out_features, bias, dtype, device),
+      tp_rank_(tp_rank),
+      tp_size_(tp_size),
+      communicator_(communicator) {
+
+    // Initialize parameters using macro; weight is partitioned along dim 1 (in_features).
+    INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device, 1, tp_rank_, tp_size_));

+    // Register bias parameter if requested; only rank 0 holds the (unpartitioned)
+    // bias so it is added exactly once across the tensor-parallel group.
+    if (bias && (0 == tp_rank_)) {
+        INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
+    } else {
+        bias_ = Parameter(); // Default constructed empty parameter
+    }
+
+    // SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
+    //              in_features, out_features, bias, static_cast<int>(dtype_));
+}
+
+Tensor RowParallelLinear::forward(Tensor &input) const {
+    auto output = BaseLinear::forward(input);
+    if ((tp_size_ > 1) && (communicator_ != nullptr)) {
+        // Each rank holds only a partial sum; all-reduce to get the full output.
+        Size count = output->numel();
+        DataType type = output->dtype();
+        infinirtStream_t stream = infinicore::context::getStream();
+        INFINICORE_CHECK_ERROR(infinicclAllReduce(output->data(), output->data(), count,
+                                                  static_cast<infiniDtype_t>(static_cast<int>(type)),
+                                                  INFINICCL_SUM, communicator_, stream));
+        INFINICORE_CHECK_ERROR(infinirtStreamSynchronize(stream));
+    }
+    return output;
+}
+
+std::string RowParallelLinear::extra_repr() const {
+    return "RowParallelLinear(in_features=" + std::to_string(in_features_) +
+           ", out_features=" + std::to_string(out_features_) +
+           ", bias=" + (has_bias_ ? "true" : "false") +
+           ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
+}
+} // namespace infinicore::nn
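
Why RowParallelLinear::forward must all-reduce: each rank multiplies against only its slice of in_features, so the per-rank outputs are partial sums that become exact only after summation across ranks. A self-contained check of that identity in plain C++ (no infinicore types; the sizes are hypothetical):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const std::size_t in = 4, out = 2, tp = 2, shard = in / tp;
    std::vector<float> x(in), w(out * in); // w is [out, in], row-major
    for (std::size_t i = 0; i < in; ++i) x[i] = float(i + 1);
    for (std::size_t i = 0; i < w.size(); ++i) w[i] = float(i);

    // Full product: y[o] = sum_k x[k] * w[o][k]  (i.e. input @ weight.T)
    std::vector<float> y_full(out, 0.f), y_tp(out, 0.f);
    for (std::size_t o = 0; o < out; ++o)
        for (std::size_t k = 0; k < in; ++k)
            y_full[o] += x[k] * w[o * in + k];

    // Per-rank partials over each rank's input shard, then "all-reduce" (sum).
    for (std::size_t rank = 0; rank < tp; ++rank)
        for (std::size_t o = 0; o < out; ++o)
            for (std::size_t k = rank * shard; k < (rank + 1) * shard; ++k)
                y_tp[o] += x[k] * w[o * in + k];

    for (std::size_t o = 0; o < out; ++o) assert(y_full[o] == y_tp[o]);
    return 0;
}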