fmriPredictor.h 12.6 KB
Newer Older
wangkx1's avatar
init  
wangkx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
/*! \file fmriPredictor.h
    \brief Contains declaration of class for making predictions about fmri data.

    \author Jesper Andersson
    \version 1.0b, Sep., 2012.
*/
// 
// fmriPredictor.h
//
// Jesper Andersson, FMRIB Image Analysis Group
//
// Copyright (C) 2022 University of Oxford
//

#ifndef fmriPredictor_h
#define fmriPredictor_h

#include <cstdlib>
#include <string>
#include <vector>
#include <cmath>
#include <memory>
#include <mutex>
#include "armawrap/newmat.h"
#include "newimage/newimageall.h"
#include "miscmaths/miscmaths.h"
#include "EddyHelperClasses.h"
#include "DWIPredictionMaker.h"
#include "KMatrix.h"
#include "HyParEstimator.h"

namespace EDDY {
/****************************************************************//**
*
* \brief Makes predictions about fMRI data.
*
* Scans are registered through SetScan() with a "global" index and a
* session number. Internally they are tracked both in a list indexed by
* global index (_glist) and in one list per session (_slist).
* Predictions use the K-matrices in _Kmats and the hyperparameter
* estimator _hpe.
*
********************************************************************/
class fmriPredictor : public DWIPredictionMaker
{
public:
  /// Constructor.
  /// \param Kmat K-matrix; the predictor stores Kmat->Clone() rather than Kmat itself.
  /// \param hpe  Hyperparameter estimator; the predictor stores hpe->Clone().
  fmriPredictor(const std::shared_ptr<const KMatrix>&        Kmat,
		const std::shared_ptr<const HyParEstimator>& hpe) EddyTry : _Kmats(1,Kmat->Clone()), _hpe(hpe->Clone()), _tr(-1.0), _lak(false) {} EddyCatch
  ~fmriPredictor() {}
  /// Returns prediction for point given by indx.
  virtual NEWIMAGE::volume<float> Predict(unsigned int indx, bool exclude=false) const;
  /// Returns prediction for point given by indx.
  virtual NEWIMAGE::volume<float> Predict(unsigned int indx, bool exclude=false);
  /// Returns prediction for point given by indx. This is used only as a means of directly comparing CPU and GPU outputs.
  virtual NEWIMAGE::volume<float> PredictCPU(unsigned int indx, bool exclude=false);
  /// Returns predictions for points given by indicies
  virtual std::vector<NEWIMAGE::volume<float> > Predict(const std::vector<unsigned int>& indicies, bool exclude=false);
  /// Returns input data for point given by indx
  virtual NEWIMAGE::volume<float> InputData(unsigned int indx) const;
  /// Returns input data for points given by indicies
  virtual std::vector<NEWIMAGE::volume<float> > InputData(const std::vector<unsigned int>& indicies) const;
  /// Returns variance of prediction for point given by indx.
  virtual double PredictionVariance(unsigned int indx, bool exclude=false);
  /// Returns measurement-error variance for point given by indx.
  virtual double ErrorVariance(unsigned int indx) const;
  /// Returns true if all data has been loaded
  virtual bool IsPopulated() const { return(is_populated()); }
  /// Indicates if it is ready to make predictions: populated, lists consistent and first K-matrix valid.
  virtual bool IsValid() const EddyTry { return(IsPopulated() && _lak && _Kmats[0]->IsValid()); } EddyCatch
  /// Set total number of scans to be loaded
  virtual void SetNoOfScans(unsigned int n);
  /// Set a point given by indx. This function is thread safe as long as different threads set different points.
  virtual void SetScan(const NEWIMAGE::volume<float>& scan, // _May_ be thread safe if used "sensibly"
		       const DiffPara&                dp,
		       unsigned int                   indx,
		       unsigned int                   sess);
  /// Set a point given by indx in session 0. This function is thread safe as long as different threads set different points.
  virtual void SetScan(const NEWIMAGE::volume<float>& scan, // _May_ be thread safe if used "sensibly"
		       const DiffPara&                dp,
		       unsigned int                   indx) EddyTry { SetScan(scan,dp,indx,0); } EddyCatch
  /// Returns the number of hyperparameters for the model
  virtual unsigned int NoOfHyperPar() const EddyTry { return(_Kmats[0]->NoOfHyperPar()); } EddyCatch
  /// Returns the hyperparameters for the model
  virtual std::vector<double> GetHyperPar() const EddyTry { return(_Kmats[0]->GetHyperPar()); } EddyCatch
  /// Evaluates the model so as to make the predictor ready to make predictions.
  virtual void EvaluateModel(const NEWIMAGE::volume<float>& mask, float fwhm, bool verbose=false);
  /// Evaluates the model (with fwhm=0) so as to make the predictor ready to make predictions.
  virtual void EvaluateModel(const NEWIMAGE::volume<float>& mask, bool verbose=false) EddyTry { EvaluateModel(mask,0.0,verbose); } EddyCatch
  /// Writes internal content to disk for debug purposes
  virtual void WriteImageData(const std::string& fname) const;
  virtual void WriteMetaData(const std::string& fname) const;
  virtual void Write(const std::string& fname) const EddyTry { WriteImageData(fname); WriteMetaData(fname); } EddyCatch
private:
  // Couple of helper classes
  /// Session membership of one scan; both fields are -1 until set.
  class sess_index {
  public:
    sess_index() : _sess(-1), _sindx(-1) {}
    int _sess;   // Session
    int _sindx;  // Index within session
  };
  /// Per-session list of image pointers, each paired with its global index.
  class ptr_index_list {
  public:
    void PushBack(const NEWIMAGE::volume<float>& scan,
		  unsigned int                   indx) EddyTry {
      _sptr_lst.push_back(std::make_shared<NEWIMAGE::volume<float> >(scan));
      _gndx_lst.push_back(static_cast<int>(indx));
    } EddyCatch
    unsigned int Size() const { return(_sptr_lst.size()); }
    std::shared_ptr<NEWIMAGE::volume<float> > SPtr(unsigned int i) const EddyTry {
      if (i >= _sptr_lst.size()) throw EddyException("fmriPredictor::Sptr: Index out of bounds");
      return(_sptr_lst[i]);
    } EddyCatch
    std::shared_ptr<NEWIMAGE::volume<float> >& SPtr(unsigned int i) EddyTry {
      if (i >= _sptr_lst.size()) throw EddyException("fmriPredictor::Sptr: Index out of bounds");
      return(_sptr_lst[i]);
    } EddyCatch
    int Gndx(unsigned int i) const EddyTry {
      // Check against the vector actually being indexed
      if (i >= _gndx_lst.size()) throw EddyException("fmriPredictor::Gndx: Index out of bounds");
      return(_gndx_lst[i]);
    } EddyCatch
    const std::vector<std::shared_ptr<NEWIMAGE::volume<float> > >& SPtr_list() const EddyTry { return(_sptr_lst); } EddyCatch
    void SortByGndx(); // Sort both _sptr_lst and _gndx_lst in ascending order of gndx
  private:
    std::vector<std::shared_ptr<NEWIMAGE::volume<float> > > _sptr_lst;
    std::vector<int>                                        _gndx_lst;
  };
  // Member variables
  std::vector<sess_index>                                 _glist;   ///< List indexed with "global" index
  std::vector<ptr_index_list>                             _slist;   ///< List indexed with session index
  std::vector<std::shared_ptr<KMatrix> >                  _Kmats;   ///< K-matrices
  std::shared_ptr<HyParEstimator>                         _hpe;     ///< Hyperparameter estimator (clone of ctor argument)
  std::vector<std::shared_ptr<NEWIMAGE::volume<float> > > _mptrs;   ///< Pointers to mean images
  double                                                  _tr;      ///< Repetition time
  bool                                                    _lak;     ///< Lists Are Kosher? (false until checked)
  std::mutex                                              _set_mut; ///< Mutex for SetNoOfScans and SetScan
  // Private utility functions
  void mean_correct();
  bool is_populated() const;
  /// Return session index for image with global index 'gindx'
  int session(int gindx) const EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::session: Predictor not yet populated");
    if (gindx < 0 || static_cast<std::size_t>(gindx) >= _glist.size()) throw EddyException("fmriPredictor::session: Index out of bounds");
    return(_glist[gindx]._sess);
  } EddyCatch
  /// Return true if all the global indicies in 'indicies' belong to the same session (vacuously true if empty)
  bool all_in_same_session(const std::vector<unsigned int>& indicies) const EddyTry {
    if (indicies.empty()) return(true); // Avoid reading indicies[0] of an empty vector
    const int sess = this->session(indicies[0]);
    for (unsigned int i=1; i<indicies.size(); i++) if (sess != this->session(indicies[i])) return(false);
    return(true);
  } EddyCatch
  /// Returns a ptr to first image. Useful for getting size and properties.
  std::shared_ptr<const NEWIMAGE::volume<float> > first_imptr() const EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::first_imptr: Predictor not yet populated");
    return(_slist[0].SPtr(0));
  } EddyCatch
  /// No of scans in same session as that given by global index gindx
  unsigned int no_of_scans_in_same_session(int gindx) const EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::no_of_scans_in_same_session: Predictor not yet populated");
    if (gindx < 0 || static_cast<std::size_t>(gindx) >= _glist.size()) throw EddyException("fmriPredictor::no_of_scans_in_same_session: Index out of bounds");
    return(_slist[_glist[gindx]._sess].Size());
  } EddyCatch
  /// Returns index within session for scan with global index "gindx"
  int index_in_session(int gindx) const EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::index_in_session: Predictor not yet populated");
    if (gindx < 0 || static_cast<std::size_t>(gindx) >= _glist.size()) throw EddyException("fmriPredictor::index_in_session: Index out of bounds");
    return(_glist[gindx]._sindx);
  } EddyCatch
  /// Returns image ptr #i from the same session as gindx
  std::shared_ptr<const NEWIMAGE::volume<float> > imptr_from_same_session(int gindx, int i) const EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::imptr_from_same_session: Predictor not yet populated");
    if (gindx < 0 || static_cast<std::size_t>(gindx) >= _glist.size()) throw EddyException("fmriPredictor::imptr_from_same_session: gindx out of bounds");
    const ptr_index_list& slst = _slist[_glist[gindx]._sess];
    if (i < 0 || static_cast<unsigned int>(i) >= slst.Size()) throw EddyException("fmriPredictor::imptr_from_same_session: i out of bounds");
    return(slst.SPtr(i));
  } EddyCatch
  /// Returns mean_image ptr from the same session as gindx. Throws if no mean image exists for that session.
  std::shared_ptr<const NEWIMAGE::volume<float> > meanptr_from_same_session(int gindx) const EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::meanptr_from_same_session: Predictor not yet populated");
    if (gindx < 0 || static_cast<std::size_t>(gindx) >= _glist.size()) throw EddyException("fmriPredictor::meanptr_from_same_session: gindx out of bounds");
    const int sess = _glist[gindx]._sess;
    // _mptrs may not (yet) hold an entry for this session; indexing it blindly would be UB
    if (sess < 0 || static_cast<std::size_t>(sess) >= _mptrs.size()) throw EddyException("fmriPredictor::meanptr_from_same_session: No mean image for session");
    return(_mptrs[sess]);
  } EddyCatch

  /*
  /// Returns all global indicies (including gindx) from the same session as gindx
  std::vector<int> global_indicies_in_same_session(int gindx) const EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::global_indicies_in_same_session: Predictor not yet populated");
    if (gindx >= _glist.size()) throw EddyException("fmriPredictor::global_indicies_in_same_session: Index out of bounds");
    std::vector<int> rval(_slist[_glist[gindx]._sess].Size());
    for (int i=0; i<rval.size(); i++) rval[i] = _slist[_glist[gindx]._sess].Gndx(i);
    return(rval);
  } EddyCatch
  /// Returns global index and image ptr #i from the same session as gindx
  std::tuple<int, std::shared_ptr<NEWIMAGE::volume<float> > > gindx_and_imptr_from_same_session(int gindx, int i) EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::gindx_and_imptr_from_same_session: Predictor not yet populated");
    if (gindx >= _glist.size()) throw EddyException("fmriPredictor::gindx_and_imptr_from_same_session: gindx out of bounds");
    if (i >= _slist[_glist[gindx]._sess].Size()) throw EddyException("fmriPredictor::gindx_and_imptr_from_same_session: i out of bounds");
    return(std::make_tuple(_slist[_glist[gindx]._sess].Gndx(i),_slist[_glist[gindx]._sess].Sptr(i)));
  }
  */

  /// Total number of scans
  unsigned int no_of_scans() const EddyTry { 
    if (!this->is_populated()) throw EddyException("fmriPredictor::no_of_scans: Predictor not yet populated");
    return(_glist.size()); 
  } EddyCatch
  /// Total number of scans in session
  unsigned int no_of_scans(unsigned int session) const EddyTry {
    if (!this->is_populated()) throw EddyException("fmriPredictor::no_of_scans: Predictor not yet populated");
    if (session >= _slist.size()) throw EddyException("fmriPredictor::no_of_scans: Index out of bounds");
    return(_slist[session].Size());
  } EddyCatch
  bool same_no_of_scans_in_all_sessions() const;
  bool lists_are_kosher();
  bool glist_has_duplicates() const;
  void predict_image_cpu(// Input
			 unsigned int             indx,
			 bool                     exclude,
			 const arma::rowvec&      pv,
			 // Output
			 NEWIMAGE::volume<float>& pi) const;
  void predict_images_cpu(// Input
			  const std::vector<unsigned int>&       indicies,
			  bool                                   exclude,
			  const std::vector<arma::rowvec>&       pvecs,
			  // Output
			  std::vector<NEWIMAGE::volume<float> >& pi) const;
  #ifdef COMPILE_GPU
  void predict_image_gpu(unsigned int             indx,
			 bool                     excl,
			 const arma::rowvec&      pv,
			 NEWIMAGE::volume<float>& ima) const;
  void predict_images_gpu(// Input
			  const std::vector<unsigned int>&       indicies,
			  bool                                   exclude,
			  const std::vector<arma::rowvec>&       pvecs,
			  // Output
			  std::vector<NEWIMAGE::volume<float> >& pi) const;
  #endif
  bool get_y(// Input
	     unsigned int i, unsigned int j, unsigned int k, unsigned int indx, bool exclude,
	     // Output
	     arma::colvec&  y) const;
};

} // End namespace EDDY

#endif // End #ifndef fmriPredictor_h