Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f1420d19
Unverified
Commit
f1420d19
authored
Jan 07, 2020
by
Zihao Ye
Committed by
GitHub
Jan 07, 2020
Browse files
[Refactor] Renaming class methods of sampler utilities to improve readability (#1180)
* upd * upd
parent
31911834
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
42 deletions
+44
-42
include/dgl/sample_utils.h
include/dgl/sample_utils.h
+29
-27
tests/cpp/test_sampler.cc
tests/cpp/test_sampler.cc
+15
-15
No files found.
include/dgl/sample_utils.h
View file @
f1420d19
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "random.h"
#include "random.h"
namespace
dgl
{
namespace
dgl
{
namespace
utils
{
template
<
template
<
typename
Idx
,
typename
Idx
,
...
@@ -24,7 +25,7 @@ template <
...
@@ -24,7 +25,7 @@ template <
bool
replace
>
bool
replace
>
class
BaseSampler
{
class
BaseSampler
{
public:
public:
virtual
Idx
d
raw
()
{
virtual
Idx
D
raw
()
{
LOG
(
INFO
)
<<
"Not implemented yet."
;
LOG
(
INFO
)
<<
"Not implemented yet."
;
return
0
;
return
0
;
}
}
...
@@ -52,14 +53,14 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -52,14 +53,14 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
std
::
vector
<
bool
>
used
;
// indicate availability, activated when replace=false;
std
::
vector
<
bool
>
used
;
// indicate availability, activated when replace=false;
std
::
vector
<
Idx
>
id_mapping
;
// index mapping, activated when replace=false;
std
::
vector
<
Idx
>
id_mapping
;
// index mapping, activated when replace=false;
inline
Idx
m
ap
(
Idx
x
)
const
{
inline
Idx
M
ap
(
Idx
x
)
const
{
// Map consecutive indices to unused elements
if
(
replace
)
if
(
replace
)
return
x
;
return
x
;
else
else
return
id_mapping
[
x
];
return
id_mapping
[
x
];
}
}
void
rebuild
(
const
std
::
vector
<
DType
>&
prob
)
{
void
Reconstruct
(
const
std
::
vector
<
DType
>&
prob
)
{
// Reconstruct alias table
N
=
0
;
N
=
0
;
accum
=
0.
;
accum
=
0.
;
taken
=
0.
;
taken
=
0.
;
...
@@ -79,7 +80,7 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -79,7 +80,7 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
std
::
fill
(
U
.
begin
(),
U
.
end
(),
avg
);
// initialize U
std
::
fill
(
U
.
begin
(),
U
.
end
(),
avg
);
// initialize U
std
::
queue
<
std
::
pair
<
Idx
,
DType
>
>
under
,
over
;
std
::
queue
<
std
::
pair
<
Idx
,
DType
>
>
under
,
over
;
for
(
Idx
i
=
0
;
i
<
N
;
++
i
)
{
for
(
Idx
i
=
0
;
i
<
N
;
++
i
)
{
DType
p
=
prob
[
m
ap
(
i
)];
DType
p
=
prob
[
M
ap
(
i
)];
if
(
p
>
avg
)
if
(
p
>
avg
)
over
.
push
(
std
::
make_pair
(
i
,
p
));
over
.
push
(
std
::
make_pair
(
i
,
p
));
else
else
...
@@ -102,33 +103,33 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -102,33 +103,33 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
}
}
public:
public:
void
reinit_s
tate
(
const
std
::
vector
<
DType
>&
prob
)
{
void
ResetS
tate
(
const
std
::
vector
<
DType
>&
prob
)
{
used
.
resize
(
prob
.
size
());
used
.
resize
(
prob
.
size
());
if
(
!
replace
)
if
(
!
replace
)
_prob
=
prob
;
_prob
=
prob
;
std
::
fill
(
used
.
begin
(),
used
.
end
(),
false
);
std
::
fill
(
used
.
begin
(),
used
.
end
(),
false
);
rebuild
(
prob
);
Reconstruct
(
prob
);
}
}
explicit
AliasSampler
(
RandomEngine
*
re
,
const
std
::
vector
<
DType
>&
prob
)
:
re
(
re
)
{
explicit
AliasSampler
(
RandomEngine
*
re
,
const
std
::
vector
<
DType
>&
prob
)
:
re
(
re
)
{
reinit_s
tate
(
prob
);
ResetS
tate
(
prob
);
}
}
~
AliasSampler
()
{}
~
AliasSampler
()
{}
Idx
d
raw
()
{
Idx
D
raw
()
{
DType
avg
=
accum
/
N
;
DType
avg
=
accum
/
N
;
if
(
!
replace
)
{
if
(
!
replace
)
{
if
(
2
*
taken
>=
accum
)
if
(
2
*
taken
>=
accum
)
rebuild
(
_prob
);
Reconstruct
(
_prob
);
while
(
true
)
{
while
(
true
)
{
DType
dice
=
re
->
Uniform
<
DType
>
(
0
,
N
);
DType
dice
=
re
->
Uniform
<
DType
>
(
0
,
N
);
Idx
i
=
static_cast
<
Idx
>
(
dice
),
rst
;
Idx
i
=
static_cast
<
Idx
>
(
dice
),
rst
;
DType
p
=
(
dice
-
i
)
*
avg
;
DType
p
=
(
dice
-
i
)
*
avg
;
if
(
p
<=
U
[
m
ap
(
i
)])
{
if
(
p
<=
U
[
M
ap
(
i
)])
{
rst
=
m
ap
(
i
);
rst
=
M
ap
(
i
);
}
else
{
}
else
{
rst
=
m
ap
(
K
[
i
]);
rst
=
M
ap
(
K
[
i
]);
}
}
DType
cap
=
_prob
[
rst
];
DType
cap
=
_prob
[
rst
];
if
(
!
used
[
rst
])
{
if
(
!
used
[
rst
])
{
...
@@ -141,10 +142,10 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -141,10 +142,10 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
DType
dice
=
re
->
Uniform
<
DType
>
(
0
,
N
);
DType
dice
=
re
->
Uniform
<
DType
>
(
0
,
N
);
Idx
i
=
static_cast
<
Idx
>
(
dice
);
Idx
i
=
static_cast
<
Idx
>
(
dice
);
DType
p
=
(
dice
-
i
)
*
avg
;
DType
p
=
(
dice
-
i
)
*
avg
;
if
(
p
<=
U
[
m
ap
(
i
)])
if
(
p
<=
U
[
M
ap
(
i
)])
return
m
ap
(
i
);
return
M
ap
(
i
);
else
else
return
m
ap
(
K
[
i
]);
return
M
ap
(
K
[
i
]);
}
}
};
};
...
@@ -170,14 +171,14 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -170,14 +171,14 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
std
::
vector
<
bool
>
used
;
// indicate availability, activated when replace=false;
std
::
vector
<
bool
>
used
;
// indicate availability, activated when replace=false;
std
::
vector
<
Idx
>
id_mapping
;
// indicate index mapping, activated when replace=false;
std
::
vector
<
Idx
>
id_mapping
;
// indicate index mapping, activated when replace=false;
inline
Idx
m
ap
(
Idx
x
)
const
{
inline
Idx
M
ap
(
Idx
x
)
const
{
// Map consecutive indices to unused elements
if
(
replace
)
if
(
replace
)
return
x
;
return
x
;
else
else
return
id_mapping
[
x
];
return
id_mapping
[
x
];
}
}
void
rebuild
(
const
std
::
vector
<
DType
>&
prob
)
{
void
Reconstruct
(
const
std
::
vector
<
DType
>&
prob
)
{
// Reconstruct CDF
N
=
0
;
N
=
0
;
accum
=
0.
;
accum
=
0.
;
taken
=
0.
;
taken
=
0.
;
...
@@ -197,28 +198,28 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -197,28 +198,28 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
}
}
public:
public:
void
reinit_s
tate
(
const
std
::
vector
<
DType
>&
prob
)
{
void
ResetS
tate
(
const
std
::
vector
<
DType
>&
prob
)
{
used
.
resize
(
prob
.
size
());
used
.
resize
(
prob
.
size
());
if
(
!
replace
)
if
(
!
replace
)
_prob
=
prob
;
_prob
=
prob
;
std
::
fill
(
used
.
begin
(),
used
.
end
(),
false
);
std
::
fill
(
used
.
begin
(),
used
.
end
(),
false
);
rebuild
(
prob
);
Reconstruct
(
prob
);
}
}
explicit
CDFSampler
(
RandomEngine
*
re
,
const
std
::
vector
<
DType
>&
prob
)
:
re
(
re
)
{
explicit
CDFSampler
(
RandomEngine
*
re
,
const
std
::
vector
<
DType
>&
prob
)
:
re
(
re
)
{
reinit_s
tate
(
prob
);
ResetS
tate
(
prob
);
}
}
~
CDFSampler
()
{}
~
CDFSampler
()
{}
Idx
d
raw
()
{
Idx
D
raw
()
{
DType
eps
=
std
::
numeric_limits
<
DType
>::
min
();
DType
eps
=
std
::
numeric_limits
<
DType
>::
min
();
if
(
!
replace
)
{
if
(
!
replace
)
{
if
(
2
*
taken
>=
accum
)
if
(
2
*
taken
>=
accum
)
rebuild
(
_prob
);
Reconstruct
(
_prob
);
while
(
true
)
{
while
(
true
)
{
DType
p
=
std
::
max
(
re
->
Uniform
<
DType
>
(
0.
,
accum
),
eps
);
DType
p
=
std
::
max
(
re
->
Uniform
<
DType
>
(
0.
,
accum
),
eps
);
Idx
rst
=
m
ap
(
std
::
lower_bound
(
cdf
.
begin
(),
cdf
.
end
(),
p
)
-
cdf
.
begin
()
-
1
);
Idx
rst
=
M
ap
(
std
::
lower_bound
(
cdf
.
begin
(),
cdf
.
end
(),
p
)
-
cdf
.
begin
()
-
1
);
DType
cap
=
_prob
[
rst
];
DType
cap
=
_prob
[
rst
];
if
(
!
used
[
rst
])
{
if
(
!
used
[
rst
])
{
used
[
rst
]
=
true
;
used
[
rst
]
=
true
;
...
@@ -228,7 +229,7 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -228,7 +229,7 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
}
}
}
}
DType
p
=
std
::
max
(
re
->
Uniform
<
DType
>
(
0.
,
accum
),
eps
);
DType
p
=
std
::
max
(
re
->
Uniform
<
DType
>
(
0.
,
accum
),
eps
);
return
m
ap
(
std
::
lower_bound
(
cdf
.
begin
(),
cdf
.
end
(),
p
)
-
cdf
.
begin
()
-
1
);
return
M
ap
(
std
::
lower_bound
(
cdf
.
begin
(),
cdf
.
end
(),
p
)
-
cdf
.
begin
()
-
1
);
}
}
};
};
...
@@ -251,7 +252,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -251,7 +252,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
int64_t
N
,
num_leafs
;
int64_t
N
,
num_leafs
;
public:
public:
void
reinit_s
tate
(
const
std
::
vector
<
DType
>&
prob
)
{
void
ResetS
tate
(
const
std
::
vector
<
DType
>&
prob
)
{
std
::
fill
(
weight
.
begin
(),
weight
.
end
(),
0
);
std
::
fill
(
weight
.
begin
(),
weight
.
end
(),
0
);
for
(
int
i
=
0
;
i
<
prob
.
size
();
++
i
)
for
(
int
i
=
0
;
i
<
prob
.
size
();
++
i
)
weight
[
num_leafs
+
i
]
=
prob
[
i
];
weight
[
num_leafs
+
i
]
=
prob
[
i
];
...
@@ -265,10 +266,10 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -265,10 +266,10 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
num_leafs
*=
2
;
num_leafs
*=
2
;
N
=
num_leafs
*
2
;
N
=
num_leafs
*
2
;
weight
.
resize
(
N
);
weight
.
resize
(
N
);
reinit_s
tate
(
prob
);
ResetS
tate
(
prob
);
}
}
Idx
d
raw
()
{
Idx
D
raw
()
{
int64_t
cur
=
1
;
int64_t
cur
=
1
;
DType
p
=
re
->
Uniform
<
DType
>
(
0
,
weight
[
cur
]);
DType
p
=
re
->
Uniform
<
DType
>
(
0
,
weight
[
cur
]);
DType
accum
=
0.
;
DType
accum
=
0.
;
...
@@ -295,6 +296,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
...
@@ -295,6 +296,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
}
}
};
};
};
// namespace utils
};
// namespace dgl
};
// namespace dgl
#endif // DGL_SAMPLE_UTILS_H_
#endif // DGL_SAMPLE_UTILS_H_
tests/cpp/test_sampler.cc
View file @
f1420d19
...
@@ -21,19 +21,19 @@ void _TestWithReplacement(RandomEngine *re) {
...
@@ -21,19 +21,19 @@ void _TestWithReplacement(RandomEngine *re) {
prob
[
i
]
/=
accum
;
prob
[
i
]
/=
accum
;
auto
_check_given_sampler
=
[
n_categories
,
n_rolls
,
&
prob
](
auto
_check_given_sampler
=
[
n_categories
,
n_rolls
,
&
prob
](
BaseSampler
<
Idx
,
DType
,
true
>
*
s
)
{
utils
::
BaseSampler
<
Idx
,
DType
,
true
>
*
s
)
{
std
::
vector
<
Idx
>
counter
(
n_categories
,
0
);
std
::
vector
<
Idx
>
counter
(
n_categories
,
0
);
for
(
Idx
i
=
0
;
i
<
n_rolls
;
++
i
)
{
for
(
Idx
i
=
0
;
i
<
n_rolls
;
++
i
)
{
Idx
dice
=
s
->
d
raw
();
Idx
dice
=
s
->
D
raw
();
counter
[
dice
]
++
;
counter
[
dice
]
++
;
}
}
for
(
Idx
i
=
0
;
i
<
n_categories
;
++
i
)
for
(
Idx
i
=
0
;
i
<
n_categories
;
++
i
)
ASSERT_NEAR
(
static_cast
<
DType
>
(
counter
[
i
])
/
n_rolls
,
prob
[
i
],
1e-2
);
ASSERT_NEAR
(
static_cast
<
DType
>
(
counter
[
i
])
/
n_rolls
,
prob
[
i
],
1e-2
);
};
};
AliasSampler
<
Idx
,
DType
,
true
>
as
(
re
,
prob
);
utils
::
AliasSampler
<
Idx
,
DType
,
true
>
as
(
re
,
prob
);
CDFSampler
<
Idx
,
DType
,
true
>
cs
(
re
,
prob
);
utils
::
CDFSampler
<
Idx
,
DType
,
true
>
cs
(
re
,
prob
);
TreeSampler
<
Idx
,
DType
,
true
>
ts
(
re
,
prob
);
utils
::
TreeSampler
<
Idx
,
DType
,
true
>
ts
(
re
,
prob
);
_check_given_sampler
(
&
as
);
_check_given_sampler
(
&
as
);
_check_given_sampler
(
&
cs
);
_check_given_sampler
(
&
cs
);
_check_given_sampler
(
&
ts
);
_check_given_sampler
(
&
ts
);
...
@@ -57,16 +57,16 @@ void _TestWithoutReplacementOrder(RandomEngine *re) {
...
@@ -57,16 +57,16 @@ void _TestWithoutReplacementOrder(RandomEngine *re) {
std
::
vector
<
Idx
>
ground_truth
=
{
0
,
3
,
2
,
1
};
std
::
vector
<
Idx
>
ground_truth
=
{
0
,
3
,
2
,
1
};
auto
_check_given_sampler
=
[
&
ground_truth
](
auto
_check_given_sampler
=
[
&
ground_truth
](
BaseSampler
<
Idx
,
DType
,
false
>
*
s
)
{
utils
::
BaseSampler
<
Idx
,
DType
,
false
>
*
s
)
{
for
(
size_t
i
=
0
;
i
<
ground_truth
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ground_truth
.
size
();
++
i
)
{
Idx
dice
=
s
->
d
raw
();
Idx
dice
=
s
->
D
raw
();
ASSERT_EQ
(
dice
,
ground_truth
[
i
]);
ASSERT_EQ
(
dice
,
ground_truth
[
i
]);
}
}
};
};
AliasSampler
<
Idx
,
DType
,
false
>
as
(
re
,
prob
);
utils
::
AliasSampler
<
Idx
,
DType
,
false
>
as
(
re
,
prob
);
CDFSampler
<
Idx
,
DType
,
false
>
cs
(
re
,
prob
);
utils
::
CDFSampler
<
Idx
,
DType
,
false
>
cs
(
re
,
prob
);
TreeSampler
<
Idx
,
DType
,
false
>
ts
(
re
,
prob
);
utils
::
TreeSampler
<
Idx
,
DType
,
false
>
ts
(
re
,
prob
);
_check_given_sampler
(
&
as
);
_check_given_sampler
(
&
as
);
_check_given_sampler
(
&
cs
);
_check_given_sampler
(
&
cs
);
_check_given_sampler
(
&
ts
);
_check_given_sampler
(
&
ts
);
...
@@ -92,19 +92,19 @@ void _TestWithoutReplacementUnique(RandomEngine *re) {
...
@@ -92,19 +92,19 @@ void _TestWithoutReplacementUnique(RandomEngine *re) {
likelihood
.
push_back
(
re
->
Uniform
<
DType
>
());
likelihood
.
push_back
(
re
->
Uniform
<
DType
>
());
auto
_check_given_sampler
=
[
N
](
auto
_check_given_sampler
=
[
N
](
BaseSampler
<
Idx
,
DType
,
false
>
*
s
)
{
utils
::
BaseSampler
<
Idx
,
DType
,
false
>
*
s
)
{
std
::
vector
<
int
>
cnt
(
N
,
0
);
std
::
vector
<
int
>
cnt
(
N
,
0
);
for
(
Idx
i
=
0
;
i
<
N
;
++
i
)
{
for
(
Idx
i
=
0
;
i
<
N
;
++
i
)
{
Idx
dice
=
s
->
d
raw
();
Idx
dice
=
s
->
D
raw
();
cnt
[
dice
]
++
;
cnt
[
dice
]
++
;
}
}
for
(
Idx
i
=
0
;
i
<
N
;
++
i
)
for
(
Idx
i
=
0
;
i
<
N
;
++
i
)
ASSERT_EQ
(
cnt
[
i
],
1
);
ASSERT_EQ
(
cnt
[
i
],
1
);
};
};
AliasSampler
<
Idx
,
DType
,
false
>
as
(
re
,
likelihood
);
utils
::
AliasSampler
<
Idx
,
DType
,
false
>
as
(
re
,
likelihood
);
CDFSampler
<
Idx
,
DType
,
false
>
cs
(
re
,
likelihood
);
utils
::
CDFSampler
<
Idx
,
DType
,
false
>
cs
(
re
,
likelihood
);
TreeSampler
<
Idx
,
DType
,
false
>
ts
(
re
,
likelihood
);
utils
::
TreeSampler
<
Idx
,
DType
,
false
>
ts
(
re
,
likelihood
);
_check_given_sampler
(
&
as
);
_check_given_sampler
(
&
as
);
_check_given_sampler
(
&
cs
);
_check_given_sampler
(
&
cs
);
_check_given_sampler
(
&
ts
);
_check_given_sampler
(
&
ts
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment