#!/usr/bin/python
# The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
#
# 
# This is an example illustrating the use of the SVM-Rank tool from the dlib C++
# Library.  This is a tool useful for learning to rank objects.  For example,
# you might use it to learn to rank web pages in response to a user's query.
# The idea is to rank the most relevant pages higher than non-relevant pages.
# 
# In this example, we will create a simple test dataset and show how to learn a
# ranking function from it.  The purpose of the function will be to give
# "relevant" objects higher scores than "non-relevant" objects.  The idea is
# that you use this score to order the objects so that the most relevant objects
# come to the top of the ranked list.
#
# COMPILING THE DLIB PYTHON INTERFACE
#   You need to compile the dlib python interface before you can use this file.
#   To do this, run compile_dlib_python_module.bat.  This should work on any
#   operating system so long as you have CMake and boost-python installed.  On
#   Ubuntu, this can be done easily by running the command: 
#   sudo apt-get install libboost-python-dev cmake


import dlib


# Now let's make some test data.  To keep it really simple, let's suppose that
# we are ranking 2D vectors and that vectors with positive values in the first
# dimension should rank higher than other vectors.  So what we do is make
# examples of relevant (i.e. high ranking) and non-relevant (i.e. low ranking)
# vectors and store them into a ranking_pair object like so:

data = dlib.ranking_pair()
data.relevant.append(dlib.vector([1, 0]))
data.nonrelevant.append(dlib.vector([0, 1]))

# Now that we have some data, we can use a machine learning method to learn a
# function that will give high scores to the relevant vectors and low scores to
# the non-relevant vectors.
trainer = dlib.svm_rank_trainer()
# Note that the trainer object has some parameters that control how it behaves.
# For example, since this is the SVM-Rank algorithm it has a C parameter that
# controls the trade-off between trying to fit the training data exactly or
# selecting a "simpler" solution which might generalize better. 
trainer.c = 10
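
# To make the trade-off controlled by C more concrete, here is a minimal
# pure-Python sketch of a pairwise hinge-loss objective of the kind SVM-Rank
# minimizes.  It is an illustration only: the helper name
# svm_rank_objective_sketch is made up for this example, and dlib's actual
# solver and loss normalization may differ.
def svm_rank_objective_sketch(w, relevant, nonrelevant, c):
    # Regularization term: smaller weights correspond to a "simpler" solution.
    reg = 0.5 * sum(wi * wi for wi in w)
    # Hinge loss: penalize each (relevant, non-relevant) pair whose score gap
    # is smaller than a margin of 1.
    loss = 0.0
    for r in relevant:
        for n in nonrelevant:
            gap = sum(wi * (ri - ni) for wi, ri, ni in zip(w, r, n))
            loss += max(0.0, 1.0 - gap)
    # A larger c makes violations of the margin more expensive relative to the
    # size of the weights.
    return reg + c * loss

# For the single training pair used in this example, the weights [0.5, -0.5]
# give a score gap of exactly 1, so the loss term vanishes and only the
# regularization term remains.
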

# So let's do the training.
rank = trainer.train(data)

# Now if you call rank on a vector it will output a ranking score.  In
# particular, the ranking score for relevant vectors should be larger than the
# score for non-relevant vectors.  
print("ranking score for a relevant vector:     {}".format(rank(data.relevant[0])))
print("ranking score for a non-relevant vector: {}".format(rank(data.nonrelevant[0])))
# These output the following:
#    ranking score for a relevant vector:     0.5
#    ranking score for a non-relevant vector: -0.5


# If we want an overall measure of ranking accuracy we can compute the ordering
# accuracy and mean average precision values by calling test_ranking_function().
# Here, the ordering accuracy tells us how often a relevant vector was ranked
# ahead of a non-relevant vector.  In this case, it returns 1 for both metrics,
# indicating that the rank function outputs a perfect ranking.
print(dlib.test_ranking_function(rank, data))
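
# As a rough illustration of the ordering accuracy metric, the sketch below
# computes the fraction of (relevant, non-relevant) pairs in which the relevant
# item gets the strictly higher score.  The helper name is hypothetical and the
# handling of ties and empty inputs is an assumption; dlib's own implementation
# may differ.
def ordering_accuracy_sketch(relevant_scores, nonrelevant_scores):
    total = len(relevant_scores) * len(nonrelevant_scores)
    if total == 0:
        # Assumption: with no pairs to compare, nothing can be misordered.
        return 1.0
    correct = sum(1 for r in relevant_scores
                  for n in nonrelevant_scores if r > n)
    return float(correct) / total
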

# We can also see the ranking weights:
print("weights: \n{}".format(rank.weights))
# In this case they are:
#  0.5 
# -0.5 
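
# Assuming the learned function is a plain linear one with no bias term (which
# is consistent with the scores and weights shown above), a ranking score is
# just the dot product of the weights with the input vector.  A minimal sketch,
# using a hypothetical helper name:
def linear_score_sketch(weights, x):
    return sum(w * xi for w, xi in zip(weights, x))

# Under that assumption, the weights [0.5, -0.5] applied to [1, 0] and [0, 1]
# give 0.5 and -0.5, the same scores reported earlier.
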



# In the above example, our data contains just two sets of objects: the
# relevant set and the non-relevant set.  The trainer is attempting to find a
# ranking function that gives every relevant vector a higher score than every
# non-relevant vector.  Sometimes what you want to do is a little more complex
# than this.
#
# For example, in the web page ranking example we have to rank pages based on a
# user's query.  In this case, each query will have its own set of relevant and
# non-relevant documents.  What might be relevant to one query may well be
# non-relevant to another.  So in this case we don't have a single global set of
# relevant web pages and another set of non-relevant web pages.  
#
# To handle cases like this, we can simply give multiple ranking_pair instances
# to the trainer.  Each ranking_pair then represents the relevant/non-relevant
# sets for a particular query.  An example is shown below (for simplicity, we
# reuse our data from above to make 4 identical "queries").

queries = dlib.ranking_pairs()
queries.append(data)
queries.append(data)
queries.append(data)
queries.append(data)

# We can train just as before.  
rank = trainer.train(queries)


# Now that we have multiple ranking_pair instances, we can also use
# cross_validate_ranking_trainer().  This performs cross-validation by splitting
# the queries up into folds.  That is, it lets the trainer train on a subset of
# ranking_pair instances and tests on the rest.  It does this over 4 different
# splits and returns the overall ranking accuracy based on the held out data.
# Just like test_ranking_function(), it reports both the ordering accuracy and
# mean average precision.
print("cross validation results: {}".format(dlib.cross_validate_ranking_trainer(trainer, queries, 4)))
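
# For intuition, splitting N queries into k folds can be sketched as below.
# This only illustrates the idea of holding out one fold at a time; the helper
# name is hypothetical and dlib's actual fold assignment may differ.
def kfold_indices_sketch(num_items, folds):
    splits = []
    for f in range(folds):
        # Every folds-th item goes into the held-out test set for this split.
        test = [i for i in range(num_items) if i % folds == f]
        train = [i for i in range(num_items) if i % folds != f]
        splits.append((train, test))
    return splits

# With 4 queries and 4 folds, each split trains on 3 queries and tests on the
# remaining one.
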